From 2ea2bc2941321317de3cb5f4365463599d95dfe5 Mon Sep 17 00:00:00 2001
From: bigcat88
Date: Thu, 24 Jul 2025 15:22:35 +0300
Subject: [PATCH] converted node files starting with the letter "t"

---
 comfy_extras/v3/nodes_tcfg.py          |  70 +++
 comfy_extras/v3/nodes_tomesd.py        | 190 +++++++
 comfy_extras/v3/nodes_torch_compile.py |  32 ++
 comfy_extras/v3/nodes_train.py         | 658 +++++++++++++++++++++++++
 nodes.py                               |   4 +
 5 files changed, 954 insertions(+)
 create mode 100644 comfy_extras/v3/nodes_tcfg.py
 create mode 100644 comfy_extras/v3/nodes_tomesd.py
 create mode 100644 comfy_extras/v3/nodes_torch_compile.py
 create mode 100644 comfy_extras/v3/nodes_train.py

diff --git a/comfy_extras/v3/nodes_tcfg.py b/comfy_extras/v3/nodes_tcfg.py
new file mode 100644
index 000000000..62c6100d6
--- /dev/null
+++ b/comfy_extras/v3/nodes_tcfg.py
@@ -0,0 +1,70 @@
+"""TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)"""
+
+from __future__ import annotations
+
+import torch
+
+from comfy_api.v3 import io
+
+
+def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
+    """Drop tangential components from uncond score to align with cond score."""
+    # (B, 1, ...)
+    batch_num = cond_score.shape[0]
+    cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
+    uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
+
+    # Score matrix A (B, 2, ...)
+    score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
+    try:
+        _, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
+    except RuntimeError:
+        # Fallback to CPU
+        _, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
+
+    # Drop the tangential components
+    v1 = Vh[:, 0:1, :].to(uncond_score_flat.device)  # (B, 1, ...)
+    uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
+    return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
+
+
+class TCFG(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TCFG_V3",
+            display_name="Tangential Damping CFG _V3",
+            category="advanced/guidance",
+            description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
+            inputs=[
+                io.Model.Input("model"),
+            ],
+            outputs=[
+                io.Model.Output(display_name="patched_model"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model):
+        m = model.clone()
+
+        def tangential_damping_cfg(args):
+            # Assume [cond, uncond, ...]
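+            # Pre-CFG hook: args["input"] is the current model input (noisy latent) and
+            # args["conds_out"] holds the model predictions in [cond, uncond, ...] order;
+            # "input - prediction" below is used as the corresponding noise/score estimate
+            # that is handed to score_tangential_damping().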
+ x = args["input"] + conds_out = args["conds_out"] + if len(conds_out) <= 1 or None in args["conds"][:2]: + # Skip when either cond or uncond is None + return conds_out + cond_pred = conds_out[0] + uncond_pred = conds_out[1] + uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred) + uncond_pred_td = x - uncond_td + return [cond_pred, uncond_pred_td] + conds_out[2:] + + m.set_model_sampler_pre_cfg_function(tangential_damping_cfg) + return io.NodeOutput(m) + + +NODES_LIST = [ + TCFG, +] diff --git a/comfy_extras/v3/nodes_tomesd.py b/comfy_extras/v3/nodes_tomesd.py new file mode 100644 index 000000000..5032b1482 --- /dev/null +++ b/comfy_extras/v3/nodes_tomesd.py @@ -0,0 +1,190 @@ +"""Taken from: https://github.com/dbolya/tomesd""" + +from __future__ import annotations + +import math +from typing import Callable, Tuple + +import torch + +from comfy_api.v3 import io + + +def do_nothing(x: torch.Tensor, mode:str=None): + return x + + +def mps_gather_workaround(input, dim, index): + if input.shape[-1] == 1: + return torch.gather( + input.unsqueeze(-1), + dim - 1 if dim < 0 else dim, + index.unsqueeze(-1) + ).squeeze(-1) + return torch.gather(input, dim, index) + + +def bipartite_soft_matching_random2d( + metric: torch.Tensor,w: int, h: int, sx: int, sy: int, r: int, no_rand: bool = False +) -> Tuple[Callable, Callable]: + """ + Partitions the tokens into src and dst and merges r tokens from src to dst. + Dst tokens are partitioned by choosing one randomy in each (sx, sy) region. + Args: + - metric [B, N, C]: metric to use for similarity + - w: image width in tokens + - h: image height in tokens + - sx: stride in the x dimension for dst, must divide w + - sy: stride in the y dimension for dst, must divide h + - r: number of tokens to remove (by merging) + - no_rand: if true, disable randomness (use top left corner only) + """ + B, N, _ = metric.shape + + if r <= 0 or w == 1 or h == 1: + return do_nothing, do_nothing + + gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather + + with torch.no_grad(): + hsy, wsx = h // sy, w // sx + + # For each sy by sx kernel, randomly assign one token to be dst and the rest src + if no_rand: + rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64) + else: + rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device) + + # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead + idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64) + idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype)) + idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx) + + # Image is not divisible by sx or sy so we need to move it into a new buffer + if (hsy * sy) < h or (wsx * sx) < w: + idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64) + idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view + else: + idx_buffer = idx_buffer_view + + # We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices + rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1) + + # We're finished with these + del idx_buffer, idx_buffer_view + + # rand_idx is currently dst|src, so split them + num_dst = hsy * wsx + a_idx = rand_idx[:, num_dst:, :] # src + b_idx = rand_idx[:, :num_dst, :] # dst + + def split(x): + C = x.shape[-1] + src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C)) + dst = gather(x, dim=1, 
index=b_idx.expand(B, num_dst, C)) + return src, dst + + # Cosine similarity between A and B + metric = metric / metric.norm(dim=-1, keepdim=True) + a, b = split(metric) + scores = a @ b.transpose(-1, -2) + + # Can't reduce more than the # tokens in src + r = min(a.shape[1], r) + + # Find the most similar greedily + node_max, node_idx = scores.max(dim=-1) + edge_idx = node_max.argsort(dim=-1, descending=True)[..., None] + + unm_idx = edge_idx[..., r:, :] # Unmerged Tokens + src_idx = edge_idx[..., :r, :] # Merged Tokens + dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx) + + def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: + src, dst = split(x) + n, t1, c = src.shape + + unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c)) + src = gather(src, dim=-2, index=src_idx.expand(n, r, c)) + dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode) + + return torch.cat([unm, dst], dim=1) + + def unmerge(x: torch.Tensor) -> torch.Tensor: + unm_len = unm_idx.shape[1] + unm, dst = x[..., :unm_len, :], x[..., unm_len:, :] + _, _, c = unm.shape + + src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c)) + + # Combine back to the original shape + out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype) + out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm) + out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src) + + return out + + return merge, unmerge + + +def get_functions(x, ratio, original_shape): + b, c, original_h, original_w = original_shape + original_tokens = original_h * original_w + downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1]))) + stride_x = 2 + stride_y = 2 + max_downsample = 1 + + if downsample <= max_downsample: + w = int(math.ceil(original_w / downsample)) + h = int(math.ceil(original_h / downsample)) + r = int(x.shape[1] * ratio) + no_rand = False + m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand) + return m, u + + def nothing(y): + return y + + return nothing, nothing + + +class TomePatchModel(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TomePatchModel_V3", + category="model_patches/unet", + inputs=[ + io.Model.Input("model"), + io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, ratio): + u = None + + def tomesd_m(q, k, v, extra_options): + nonlocal u + #NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q + #however from my basic testing it seems that using q instead gives better results + m, u = get_functions(q, ratio, extra_options["original_shape"]) + return m(q), k, v + + def tomesd_u(n, extra_options): + return u(n) + + m = model.clone() + m.set_model_attn1_patch(tomesd_m) + m.set_model_attn1_output_patch(tomesd_u) + return io.NodeOutput(m) + + +NODES_LIST = [ + TomePatchModel, +] diff --git a/comfy_extras/v3/nodes_torch_compile.py b/comfy_extras/v3/nodes_torch_compile.py new file mode 100644 index 000000000..528de0e86 --- /dev/null +++ b/comfy_extras/v3/nodes_torch_compile.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from comfy_api.torch_helpers import set_torch_compile_wrapper +from comfy_api.v3 import io + + +class TorchCompileModel(io.ComfyNode): + @classmethod + 
def define_schema(cls): + return io.Schema( + node_id="TorchCompileModel_V3", + category="_for_testing", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Combo.Input("backend", options=["inductor", "cudagraphs"]), + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, backend): + m = model.clone() + set_torch_compile_wrapper(model=m, backend=backend) + return io.NodeOutput(m) + + +NODES_LIST = [ + TorchCompileModel, +] diff --git a/comfy_extras/v3/nodes_train.py b/comfy_extras/v3/nodes_train.py new file mode 100644 index 000000000..1c9290bbf --- /dev/null +++ b/comfy_extras/v3/nodes_train.py @@ -0,0 +1,658 @@ +from __future__ import annotations + +import logging +import os + +import numpy as np +import safetensors +import torch +import torch.utils.checkpoint +import tqdm +from PIL import Image, ImageDraw, ImageFont + +import comfy.model_management +import comfy.samplers +import comfy.sd +import comfy.utils +import comfy_extras.nodes_custom_sampler +import folder_paths +import node_helpers +from comfy.weight_adapter import adapters +from comfy_api.v3 import io, ui + + +def make_batch_extra_option_dict(d, indicies, full_size=None): + new_dict = {} + for k, v in d.items(): + newv = v + if isinstance(v, dict): + newv = make_batch_extra_option_dict(v, indicies, full_size=full_size) + elif isinstance(v, torch.Tensor): + if full_size is None or v.size(0) == full_size: + newv = v[indicies] + elif isinstance(v, (list, tuple)) and len(v) == full_size: + newv = [v[i] for i in indicies] + new_dict[k] = newv + return new_dict + + +class TrainSampler(comfy.samplers.Sampler): + + def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): + self.loss_fn = loss_fn + self.optimizer = optimizer + self.loss_callback = loss_callback + self.batch_size = batch_size + self.total_steps = total_steps + self.seed = seed + self.training_dtype = training_dtype + + def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + cond = model_wrap.conds["positive"] + dataset_size = sigmas.size(0) + torch.cuda.empty_cache() + for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): + noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000) + indicies = torch.randperm(dataset_size)[:self.batch_size].tolist() + + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device) + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) for _ in range(min(self.batch_size, dataset_size)) + ] + batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) + + xt = model_wrap.inner_model.model_sampling.noise_scaling( + batch_sigmas, + batch_noise, + batch_latent, + False + ) + x0 = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(batch_sigmas), + torch.zeros_like(batch_noise), + batch_latent, + False + ) + + model_wrap.conds["positive"] = [ + cond[i] for i in indicies + ] + batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size) + + with torch.autocast(xt.device.type, dtype=self.training_dtype): + x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args) + loss = self.loss_fn(x0_pred, x0) + loss.backward() + if self.loss_callback: + 
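+                # Report the scalar loss for this step; TrainLoraNode passes a callback
+                # that appends it to loss_map["loss"], which feeds the LossGraphNode plot.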
self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + + self.optimizer.step() + self.optimizer.zero_grad() + torch.cuda.empty_cache() + return torch.zeros_like(latent_image) + + +class BiasDiff(torch.nn.Module): + def __init__(self, bias): + super().__init__() + self.bias = bias + + def __call__(self, b): + org_dtype = b.dtype + return (b.to(self.bias) + self.bias).to(org_dtype) + + def passive_memory_usage(self): + return self.bias.nelement() * self.bias.element_size() + + def move_to(self, device): + self.to(device=device) + return self.passive_memory_usage() + + +def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None): + """Utility function to load and process a list of images. + + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + + if w is None and h is None: + w, h = img.size[0], img.size[1] + + # Resize image to first image + if img.size[0] != w or img.size[1] != h: + if resize_method == "Stretch": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "Crop": + img = img.crop((0, 0, w, h)) + elif resize_method == "Pad": + img = img.resize((w, h), Image.Resampling.LANCZOS) + elif resize_method == "None": + raise ValueError( + "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." 
+ ) + + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return torch.cat(output_images, dim=0) + + +def draw_loss_graph(loss_map, steps): + width, height = 500, 300 + img = Image.new("RGB", (width, height), "white") + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_map.values()), max(loss_map.values()) + scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_map.values()] + + prev_point = (0, height - int(scaled_loss[0] * height)) + for i, l_v in enumerate(scaled_loss[1:], start=1): + x = int(i / (steps - 1) * width) + y = height - int(l_v * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + return img + + +def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): + if result is None: + result = [] + elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + result.append(model) + logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") + return result + name = name or "root" + for next_name, child in model.named_children(): + find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}") + return result + + +def patch(m): + if not hasattr(m, "forward"): + return + org_forward = m.forward + def fwd(args, kwargs): + return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): + return torch.utils.checkpoint.checkpoint( + fwd, args, kwargs, use_reentrant=False + ) + m.org_forward = org_forward + m.forward = checkpointing_fwd + + +def unpatch(m): + if hasattr(m, "org_forward"): + m.forward = m.org_forward + del m.org_forward + + +class LoadImageSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageSetFromFolderNode_V3", + display_name="Load Image Dataset from Folder _V3", + category="loaders", + description="Loads a batch of images from a directory for training.", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from." 
+ ), + io.Combo.Input( + "resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True + ), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, folder, resize_method="None"): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + return io.NodeOutput(load_and_process_images(image_files, sub_input_dir, resize_method)) + + +class LoadImageTextSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageTextSetFromFolderNode_V3", + display_name="Load Image and Text Dataset from Folder _V3", + category="loaders", + description="Loads a batch of images and caption from a directory for training.", + is_experimental=True, + inputs=[ + io.Combo.Input("folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."), + io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."), + io.Combo.Input("resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True), + io.Int.Input("width", default=-1, min=-1, max=10000, step=1, tooltip="The width to resize the images to. -1 means use the original width.", optional=True), + io.Int.Input("height", default=-1, min=-1, max=10000, step=1, tooltip="The height to resize the images to. -1 means use the original height.", optional=True), + ], + outputs=[ + io.Image.Output(), + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, folder, clip, resize_method="None", width=None, height=None): + if clip is None: + raise RuntimeError( + "ERROR: clip input is invalid: None\n\n" + "If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model." 
+ ) + + logging.info(f"Loading images from folder: {folder}") + + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + + image_files = [] + for item in os.listdir(sub_input_dir): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in valid_extensions): + image_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + image_files.extend([ + os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions) + ] * repeat) + + caption_file_path = [ + f.replace(os.path.splitext(f)[1], ".txt") + for f in image_files + ] + captions = [] + for caption_file in caption_file_path: + caption_path = os.path.join(sub_input_dir, caption_file) + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + captions.append(caption) + else: + captions.append("") + + width = width if width != -1 else None + height = height if height != -1 else None + output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height) + + logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") + + logging.info(f"Encoding captions from {sub_input_dir}.") + conditions = [] + empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) + for text in captions: + if text == "": + conditions.append(empty_cond) + tokens = clip.tokenize(text) + conditions.extend(clip.encode_from_tokens_scheduled(tokens)) + logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.") + return io.NodeOutput(output_tensor, conditions) + + +class LoraModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoraModelLoader_V3", + display_name="Load LoRA Model _V3", + category="loaders", + description="Load Trained LoRA weights from Train LoRA node.", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."), + io.LoraModel.Input("lora", tooltip="The LoRA model to apply to the diffusion model."), + io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. 
This value can be negative."), + ], + outputs=[ + io.Model.Output(tooltip="The modified diffusion model."), + ], + ) + + @classmethod + def execute(cls, model, lora, strength_model): + if strength_model == 0: + return io.NodeOutput(model) + + model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0) + return io.NodeOutput(model_lora) + + +class LossGraphNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LossGraphNode_V3", + display_name="Plot Loss Graph _V3", + category="training", + description="Plots the loss graph and saves it to the output directory.", + is_experimental=True, + is_output_node=True, + inputs=[ + io.LossMap.Input("loss"), # TODO: original V1 node has also `default={}` parameter + io.String.Input("filename_prefix", default="loss_graph"), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + ) + + @classmethod + def execute(cls, loss, filename_prefix): + loss_values = loss["loss"] + width, height = 800, 480 + margin = 40 + + img = Image.new( + "RGB", (width + margin, height + margin), "white" + ) # Extend canvas + draw = ImageDraw.Draw(img) + + min_loss, max_loss = min(loss_values), max(loss_values) + scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_values] + + steps = len(loss_values) + + prev_point = (margin, height - int(scaled_loss[0] * height)) + for i, l_v in enumerate(scaled_loss[1:], start=1): + x = margin + int(i / steps * width) # Scale X properly + y = height - int(l_v * height) + draw.line([prev_point, (x, y)], fill="blue", width=2) + prev_point = (x, y) + + draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis + draw.line( + [(margin, height), (width + margin, height)], fill="black", width=2 + ) # X-axis + + try: + font = ImageFont.truetype("arial.ttf", 12) + except IOError: + font = ImageFont.load_default() + + # Add axis labels + draw.text((5, height // 2), "Loss", font=font, fill="black") + draw.text((width // 2, height + 10), "Steps", font=font, fill="black") + + # Add min/max loss values + draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black") + draw.text( + (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" + ) + return io.NodeOutput(ui=ui.PreviewImage(img, cls=cls)) + + +class SaveLoRA(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveLoRA_V3", + display_name="Save LoRA Weights _V3", + category="loaders", + is_experimental=True, + is_output_node=True, + inputs=[ + io.LoraModel.Input("lora", tooltip="The LoRA model to save. 
Do not use the model with LoRA layers."), + io.String.Input("prefix", default="loras/ComfyUI_trained_lora", tooltip="The prefix to use for the saved LoRA file."), + io.Int.Input("steps", tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", optional=True), + ], + outputs=[], + ) + + @classmethod + def execute(cls, lora, prefix, steps=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + prefix, folder_paths.get_output_directory() + ) + if steps is None: + output_checkpoint = f"{filename}_{counter:05}_.safetensors" + else: + output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" + output_checkpoint = os.path.join(full_output_folder, output_checkpoint) + safetensors.torch.save_file(lora, output_checkpoint) + return io.NodeOutput() + + +class TrainLoraNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TrainLoraNode_V3", + display_name="Train LoRA _V3", + category="training", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to train the LoRA on."), + io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."), + io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."), + io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."), + io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."), + io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."), + io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."), + io.Combo.Input("optimizer", options=["AdamW", "Adam", "SGD", "RMSprop"], default="AdamW", tooltip="The optimizer to use for training."), + io.Combo.Input("loss_function", options=["MSE", "L1", "Huber", "SmoothL1"], default="MSE", tooltip="The loss function to use for training."), + io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"), + io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."), + io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."), + io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. 
Set to None for new LoRA."), + ], + outputs=[ + io.Model.Output(display_name="model_with_lora"), + io.LoraModel.Output(display_name="lora"), + io.LossMap.Output(display_name="loss"), + io.Int.Output(display_name="steps"), + ], + ) + + @classmethod + def execute( + cls, + model, + latents, + positive, + batch_size, + steps, + learning_rate, + rank, + optimizer, + loss_function, + seed, + training_dtype, + lora_dtype, + existing_lora, + ): + mp = model.clone() + dtype = node_helpers.string_to_torch_dtype(training_dtype) + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) + mp.set_model_compute_dtype(dtype) + + latents = latents["samples"].to(dtype) + num_images = latents.shape[0] + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + if len(positive) == 1 and num_images > 1: + positive = positive * num_images + elif len(positive) != num_images: + raise ValueError( + f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})." + ) + + with torch.inference_mode(False): + lora_sd = {} + generator = torch.Generator() + generator.manual_seed(seed) + + # Load existing LoRA weights if provided + existing_weights = {} + existing_steps = 0 + if existing_lora != "[None]": + lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora) + # Extract steps from filename like "trained_lora_10_steps_20250225_203716" + existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1]) + if lora_path: + existing_weights = comfy.utils.load_torch_file(lora_path) + + all_weight_adapters = [] + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + key = "{}.weight".format(n) + shape = m.weight.shape + if len(shape) >= 2: + alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) + dora_scale = existing_weights.get( + f"{key}.dora_scale", None + ) + for adapter_cls in adapters: + existing_adapter = adapter_cls.load( + n, existing_weights, alpha, dora_scale + ) + if existing_adapter is not None: + break + else: + # If no existing adapter found, use LoRA + # We will add algo option in the future + existing_adapter = None + adapter_cls = adapters[0] + + if existing_adapter is not None: + train_adapter = existing_adapter.to_train().to(lora_dtype) + else: + # Use LoRA with alpha=1.0 by default + train_adapter = adapter_cls.create_train( + m.weight, rank=rank, alpha=1.0 + ).to(lora_dtype) + for name, parameter in train_adapter.named_parameters(): + lora_sd[f"{n}.{name}"] = parameter + + mp.add_weight_wrapper(key, train_adapter) + all_weight_adapters.append(train_adapter) + else: + diff = torch.nn.Parameter( + torch.zeros( + m.weight.shape, dtype=lora_dtype, requires_grad=True + ) + ) + diff_module = BiasDiff(diff) + mp.add_weight_wrapper(key, BiasDiff(diff)) + all_weight_adapters.append(diff_module) + lora_sd["{}.diff".format(n)] = diff + if hasattr(m, "bias") and m.bias is not None: + key = "{}.bias".format(n) + bias = torch.nn.Parameter( + torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + ) + bias_module = BiasDiff(bias) + lora_sd["{}.diff_b".format(n)] = bias + mp.add_weight_wrapper(key, BiasDiff(bias)) + all_weight_adapters.append(bias_module) + + if optimizer == "Adam": + optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate) + elif optimizer == "AdamW": + optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate) + elif optimizer == "SGD": + optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate) + elif optimizer == "RMSprop": + optimizer 
= torch.optim.RMSprop(lora_sd.values(), lr=learning_rate) + + # Setup loss function based on selection + if loss_function == "MSE": + criterion = torch.nn.MSELoss() + elif loss_function == "L1": + criterion = torch.nn.L1Loss() + elif loss_function == "Huber": + criterion = torch.nn.HuberLoss() + elif loss_function == "SmoothL1": + criterion = torch.nn.SmoothL1Loss() + + # setup models + for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + patch(m) + mp.model.requires_grad_(False) + comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + + # Setup sampler and guider like in test script + loss_map = {"loss": []} + def loss_callback(loss): + loss_map["loss"].append(loss) + train_sampler = TrainSampler( + criterion, + optimizer, + loss_callback=loss_callback, + batch_size=batch_size, + total_steps=steps, + seed=seed, + training_dtype=dtype + ) + guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) + guider.set_conds(positive) # Set conditioning from input + + # Training loop + try: + # Generate dummy sigmas and noise + sigmas = torch.tensor(range(num_images)) + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) + guider.sample( + noise.generate_noise({"samples": latents}), + latents, + train_sampler, + sigmas, + seed=noise.seed + ) + finally: + for m in mp.model.modules(): + unpatch(m) + del train_sampler, optimizer + + for adapter in all_weight_adapters: + adapter.requires_grad_(False) + + for param in lora_sd: + lora_sd[param] = lora_sd[param].to(lora_dtype) + + return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) + + +NODES_LIST = [ + LoadImageSetFromFolderNode, + LoadImageTextSetFromFolderNode, + LoraModelLoader, + LossGraphNode, + SaveLoRA, + TrainLoraNode, +] diff --git a/nodes.py b/nodes.py index 296aa0027..e2805463c 100644 --- a/nodes.py +++ b/nodes.py @@ -2350,6 +2350,10 @@ def init_builtin_extra_nodes(): "v3/nodes_sdupscale.py", "v3/nodes_slg.py", "v3/nodes_stable_cascade.py", + "v3/nodes_tcfg.py", + "v3/nodes_tomesd.py", + "v3/nodes_torch_compile.py", + "v3/nodes_train.py", "v3/nodes_upscale_model.py", "v3/nodes_video.py", "v3/nodes_video_model.py",
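A minimal sketch for reviewers (not part of the patch) showing what score_tangential_damping in comfy_extras/v3/nodes_tcfg.py does with its inputs. It assumes a ComfyUI checkout where comfy_extras.v3 is importable as a package; the toy tensor shapes are arbitrary and only chosen for illustration.

import torch

from comfy_extras.v3.nodes_tcfg import score_tangential_damping

cond_score = torch.randn(2, 4, 8, 8)    # stands in for x - cond_pred in the pre-CFG hook
uncond_score = torch.randn(2, 4, 8, 8)  # stands in for x - uncond_pred

damped = score_tangential_damping(cond_score, uncond_score)

# Shape and dtype of the uncond score are preserved.
assert damped.shape == uncond_score.shape and damped.dtype == uncond_score.dtype

# When uncond already equals cond, the stacked score matrix is rank 1 and the
# projection onto its leading right-singular direction returns the score unchanged
# (up to float error).
identical = score_tangential_damping(cond_score, cond_score)
assert torch.allclose(identical, cond_score, atol=1e-4)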