mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-28 00:36:32 +00:00
converted nodes files starting with "t" letter
This commit is contained in:
parent
487ec28b9c
commit
2ea2bc2941
70
comfy_extras/v3/nodes_tcfg.py
Normal file
70
comfy_extras/v3/nodes_tcfg.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""TCFG: Tangential Damping Classifier-free Guidance - (arXiv: https://arxiv.org/abs/2503.18137)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def score_tangential_damping(cond_score: torch.Tensor, uncond_score: torch.Tensor) -> torch.Tensor:
|
||||
"""Drop tangential components from uncond score to align with cond score."""
|
||||
# (B, 1, ...)
|
||||
batch_num = cond_score.shape[0]
|
||||
cond_score_flat = cond_score.reshape(batch_num, 1, -1).float()
|
||||
uncond_score_flat = uncond_score.reshape(batch_num, 1, -1).float()
|
||||
|
||||
# Score matrix A (B, 2, ...)
|
||||
score_matrix = torch.cat((uncond_score_flat, cond_score_flat), dim=1)
|
||||
try:
|
||||
_, _, Vh = torch.linalg.svd(score_matrix, full_matrices=False)
|
||||
except RuntimeError:
|
||||
# Fallback to CPU
|
||||
_, _, Vh = torch.linalg.svd(score_matrix.cpu(), full_matrices=False)
|
||||
|
||||
# Drop the tangential components
|
||||
v1 = Vh[:, 0:1, :].to(uncond_score_flat.device) # (B, 1, ...)
|
||||
uncond_score_td = (uncond_score_flat @ v1.transpose(-2, -1)) * v1
|
||||
return uncond_score_td.reshape_as(uncond_score).to(uncond_score.dtype)
|
||||
|
||||
|
||||
class TCFG(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TCFG_V3",
|
||||
display_name="Tangential Damping CFG _V3",
|
||||
category="advanced/guidance",
|
||||
description="TCFG – Tangential Damping CFG (2503.18137)\n\nRefine the uncond (negative) to align with the cond (positive) for improving quality.",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="patched_model"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model):
|
||||
m = model.clone()
|
||||
|
||||
def tangential_damping_cfg(args):
|
||||
# Assume [cond, uncond, ...]
|
||||
x = args["input"]
|
||||
conds_out = args["conds_out"]
|
||||
if len(conds_out) <= 1 or None in args["conds"][:2]:
|
||||
# Skip when either cond or uncond is None
|
||||
return conds_out
|
||||
cond_pred = conds_out[0]
|
||||
uncond_pred = conds_out[1]
|
||||
uncond_td = score_tangential_damping(x - cond_pred, x - uncond_pred)
|
||||
uncond_pred_td = x - uncond_td
|
||||
return [cond_pred, uncond_pred_td] + conds_out[2:]
|
||||
|
||||
m.set_model_sampler_pre_cfg_function(tangential_damping_cfg)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
TCFG,
|
||||
]
|
190
comfy_extras/v3/nodes_tomesd.py
Normal file
190
comfy_extras/v3/nodes_tomesd.py
Normal file
@ -0,0 +1,190 @@
|
||||
"""Taken from: https://github.com/dbolya/tomesd"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Callable, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
def do_nothing(x: torch.Tensor, mode:str=None):
|
||||
return x
|
||||
|
||||
|
||||
def mps_gather_workaround(input, dim, index):
|
||||
if input.shape[-1] == 1:
|
||||
return torch.gather(
|
||||
input.unsqueeze(-1),
|
||||
dim - 1 if dim < 0 else dim,
|
||||
index.unsqueeze(-1)
|
||||
).squeeze(-1)
|
||||
return torch.gather(input, dim, index)
|
||||
|
||||
|
||||
def bipartite_soft_matching_random2d(
|
||||
metric: torch.Tensor,w: int, h: int, sx: int, sy: int, r: int, no_rand: bool = False
|
||||
) -> Tuple[Callable, Callable]:
|
||||
"""
|
||||
Partitions the tokens into src and dst and merges r tokens from src to dst.
|
||||
Dst tokens are partitioned by choosing one randomy in each (sx, sy) region.
|
||||
Args:
|
||||
- metric [B, N, C]: metric to use for similarity
|
||||
- w: image width in tokens
|
||||
- h: image height in tokens
|
||||
- sx: stride in the x dimension for dst, must divide w
|
||||
- sy: stride in the y dimension for dst, must divide h
|
||||
- r: number of tokens to remove (by merging)
|
||||
- no_rand: if true, disable randomness (use top left corner only)
|
||||
"""
|
||||
B, N, _ = metric.shape
|
||||
|
||||
if r <= 0 or w == 1 or h == 1:
|
||||
return do_nothing, do_nothing
|
||||
|
||||
gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
|
||||
|
||||
with torch.no_grad():
|
||||
hsy, wsx = h // sy, w // sx
|
||||
|
||||
# For each sy by sx kernel, randomly assign one token to be dst and the rest src
|
||||
if no_rand:
|
||||
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
|
||||
else:
|
||||
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
|
||||
|
||||
# The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
|
||||
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
|
||||
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
|
||||
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)
|
||||
|
||||
# Image is not divisible by sx or sy so we need to move it into a new buffer
|
||||
if (hsy * sy) < h or (wsx * sx) < w:
|
||||
idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
|
||||
idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
|
||||
else:
|
||||
idx_buffer = idx_buffer_view
|
||||
|
||||
# We set dst tokens to be -1 and src to be 0, so an argsort gives us dst|src indices
|
||||
rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)
|
||||
|
||||
# We're finished with these
|
||||
del idx_buffer, idx_buffer_view
|
||||
|
||||
# rand_idx is currently dst|src, so split them
|
||||
num_dst = hsy * wsx
|
||||
a_idx = rand_idx[:, num_dst:, :] # src
|
||||
b_idx = rand_idx[:, :num_dst, :] # dst
|
||||
|
||||
def split(x):
|
||||
C = x.shape[-1]
|
||||
src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
|
||||
dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
|
||||
return src, dst
|
||||
|
||||
# Cosine similarity between A and B
|
||||
metric = metric / metric.norm(dim=-1, keepdim=True)
|
||||
a, b = split(metric)
|
||||
scores = a @ b.transpose(-1, -2)
|
||||
|
||||
# Can't reduce more than the # tokens in src
|
||||
r = min(a.shape[1], r)
|
||||
|
||||
# Find the most similar greedily
|
||||
node_max, node_idx = scores.max(dim=-1)
|
||||
edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]
|
||||
|
||||
unm_idx = edge_idx[..., r:, :] # Unmerged Tokens
|
||||
src_idx = edge_idx[..., :r, :] # Merged Tokens
|
||||
dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
|
||||
|
||||
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
|
||||
src, dst = split(x)
|
||||
n, t1, c = src.shape
|
||||
|
||||
unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
|
||||
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
|
||||
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
|
||||
|
||||
return torch.cat([unm, dst], dim=1)
|
||||
|
||||
def unmerge(x: torch.Tensor) -> torch.Tensor:
|
||||
unm_len = unm_idx.shape[1]
|
||||
unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
|
||||
_, _, c = unm.shape
|
||||
|
||||
src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))
|
||||
|
||||
# Combine back to the original shape
|
||||
out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
|
||||
out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
|
||||
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
|
||||
out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)
|
||||
|
||||
return out
|
||||
|
||||
return merge, unmerge
|
||||
|
||||
|
||||
def get_functions(x, ratio, original_shape):
|
||||
b, c, original_h, original_w = original_shape
|
||||
original_tokens = original_h * original_w
|
||||
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
|
||||
stride_x = 2
|
||||
stride_y = 2
|
||||
max_downsample = 1
|
||||
|
||||
if downsample <= max_downsample:
|
||||
w = int(math.ceil(original_w / downsample))
|
||||
h = int(math.ceil(original_h / downsample))
|
||||
r = int(x.shape[1] * ratio)
|
||||
no_rand = False
|
||||
m, u = bipartite_soft_matching_random2d(x, w, h, stride_x, stride_y, r, no_rand)
|
||||
return m, u
|
||||
|
||||
def nothing(y):
|
||||
return y
|
||||
|
||||
return nothing, nothing
|
||||
|
||||
|
||||
class TomePatchModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TomePatchModel_V3",
|
||||
category="model_patches/unet",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, ratio):
|
||||
u = None
|
||||
|
||||
def tomesd_m(q, k, v, extra_options):
|
||||
nonlocal u
|
||||
#NOTE: In the reference code get_functions takes x (input of the transformer block) as the argument instead of q
|
||||
#however from my basic testing it seems that using q instead gives better results
|
||||
m, u = get_functions(q, ratio, extra_options["original_shape"])
|
||||
return m(q), k, v
|
||||
|
||||
def tomesd_u(n, extra_options):
|
||||
return u(n)
|
||||
|
||||
m = model.clone()
|
||||
m.set_model_attn1_patch(tomesd_m)
|
||||
m.set_model_attn1_output_patch(tomesd_u)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
TomePatchModel,
|
||||
]
|
32
comfy_extras/v3/nodes_torch_compile.py
Normal file
32
comfy_extras/v3/nodes_torch_compile.py
Normal file
@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from comfy_api.torch_helpers import set_torch_compile_wrapper
|
||||
from comfy_api.v3 import io
|
||||
|
||||
|
||||
class TorchCompileModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TorchCompileModel_V3",
|
||||
category="_for_testing",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Combo.Input("backend", options=["inductor", "cudagraphs"]),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, backend):
|
||||
m = model.clone()
|
||||
set_torch_compile_wrapper(model=m, backend=backend)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
TorchCompileModel,
|
||||
]
|
658
comfy_extras/v3/nodes_train.py
Normal file
658
comfy_extras/v3/nodes_train.py
Normal file
@ -0,0 +1,658 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import tqdm
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.samplers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy.weight_adapter import adapters
|
||||
from comfy_api.v3 import io, ui
|
||||
|
||||
|
||||
def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||
new_dict = {}
|
||||
for k, v in d.items():
|
||||
newv = v
|
||||
if isinstance(v, dict):
|
||||
newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
if full_size is None or v.size(0) == full_size:
|
||||
newv = v[indicies]
|
||||
elif isinstance(v, (list, tuple)) and len(v) == full_size:
|
||||
newv = [v[i] for i in indicies]
|
||||
new_dict[k] = newv
|
||||
return new_dict
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
self.loss_callback = loss_callback
|
||||
self.batch_size = batch_size
|
||||
self.total_steps = total_steps
|
||||
self.seed = seed
|
||||
self.training_dtype = training_dtype
|
||||
|
||||
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
|
||||
cond = model_wrap.conds["positive"]
|
||||
dataset_size = sigmas.size(0)
|
||||
torch.cuda.empty_cache()
|
||||
for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
|
||||
noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
|
||||
indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
|
||||
|
||||
batch_latent = torch.stack([latent_image[i] for i in indicies])
|
||||
batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
|
||||
batch_sigmas = [
|
||||
model_wrap.inner_model.model_sampling.percent_to_sigma(
|
||||
torch.rand((1,)).item()
|
||||
) for _ in range(min(self.batch_size, dataset_size))
|
||||
]
|
||||
batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
|
||||
|
||||
xt = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
batch_sigmas,
|
||||
batch_noise,
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
x0 = model_wrap.inner_model.model_sampling.noise_scaling(
|
||||
torch.zeros_like(batch_sigmas),
|
||||
torch.zeros_like(batch_noise),
|
||||
batch_latent,
|
||||
False
|
||||
)
|
||||
|
||||
model_wrap.conds["positive"] = [
|
||||
cond[i] for i in indicies
|
||||
]
|
||||
batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
|
||||
|
||||
with torch.autocast(xt.device.type, dtype=self.training_dtype):
|
||||
x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
|
||||
loss = self.loss_fn(x0_pred, x0)
|
||||
loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
return torch.zeros_like(latent_image)
|
||||
|
||||
|
||||
class BiasDiff(torch.nn.Module):
|
||||
def __init__(self, bias):
|
||||
super().__init__()
|
||||
self.bias = bias
|
||||
|
||||
def __call__(self, b):
|
||||
org_dtype = b.dtype
|
||||
return (b.to(self.bias) + self.bias).to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return self.bias.nelement() * self.bias.element_size()
|
||||
|
||||
def move_to(self, device):
|
||||
self.to(device=device)
|
||||
return self.passive_memory_usage()
|
||||
|
||||
|
||||
def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None):
|
||||
"""Utility function to load and process a list of images.
|
||||
|
||||
Args:
|
||||
image_files: List of image filenames
|
||||
input_dir: Base directory containing the images
|
||||
resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad")
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Batch of processed images
|
||||
"""
|
||||
if not image_files:
|
||||
raise ValueError("No valid images found in input")
|
||||
|
||||
output_images = []
|
||||
|
||||
for file in image_files:
|
||||
image_path = os.path.join(input_dir, file)
|
||||
img = node_helpers.pillow(Image.open, image_path)
|
||||
|
||||
if img.mode == "I":
|
||||
img = img.point(lambda i: i * (1 / 255))
|
||||
img = img.convert("RGB")
|
||||
|
||||
if w is None and h is None:
|
||||
w, h = img.size[0], img.size[1]
|
||||
|
||||
# Resize image to first image
|
||||
if img.size[0] != w or img.size[1] != h:
|
||||
if resize_method == "Stretch":
|
||||
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||
elif resize_method == "Crop":
|
||||
img = img.crop((0, 0, w, h))
|
||||
elif resize_method == "Pad":
|
||||
img = img.resize((w, h), Image.Resampling.LANCZOS)
|
||||
elif resize_method == "None":
|
||||
raise ValueError(
|
||||
"Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images."
|
||||
)
|
||||
|
||||
img_array = np.array(img).astype(np.float32) / 255.0
|
||||
img_tensor = torch.from_numpy(img_array)[None,]
|
||||
output_images.append(img_tensor)
|
||||
|
||||
return torch.cat(output_images, dim=0)
|
||||
|
||||
|
||||
def draw_loss_graph(loss_map, steps):
|
||||
width, height = 500, 300
|
||||
img = Image.new("RGB", (width, height), "white")
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
min_loss, max_loss = min(loss_map.values()), max(loss_map.values())
|
||||
scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_map.values()]
|
||||
|
||||
prev_point = (0, height - int(scaled_loss[0] * height))
|
||||
for i, l_v in enumerate(scaled_loss[1:], start=1):
|
||||
x = int(i / (steps - 1) * width)
|
||||
y = height - int(l_v * height)
|
||||
draw.line([prev_point, (x, y)], fill="blue", width=2)
|
||||
prev_point = (x, y)
|
||||
|
||||
return img
|
||||
|
||||
|
||||
def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None):
|
||||
if result is None:
|
||||
result = []
|
||||
elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)):
|
||||
result.append(model)
|
||||
logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})")
|
||||
return result
|
||||
name = name or "root"
|
||||
for next_name, child in model.named_children():
|
||||
find_all_highest_child_module_with_forward(child, result, f"{name}.{next_name}")
|
||||
return result
|
||||
|
||||
|
||||
def patch(m):
|
||||
if not hasattr(m, "forward"):
|
||||
return
|
||||
org_forward = m.forward
|
||||
def fwd(args, kwargs):
|
||||
return org_forward(*args, **kwargs)
|
||||
def checkpointing_fwd(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(
|
||||
fwd, args, kwargs, use_reentrant=False
|
||||
)
|
||||
m.org_forward = org_forward
|
||||
m.forward = checkpointing_fwd
|
||||
|
||||
|
||||
def unpatch(m):
|
||||
if hasattr(m, "org_forward"):
|
||||
m.forward = m.org_forward
|
||||
del m.org_forward
|
||||
|
||||
|
||||
class LoadImageSetFromFolderNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageSetFromFolderNode_V3",
|
||||
display_name="Load Image Dataset from Folder _V3",
|
||||
category="loaders",
|
||||
description="Loads a batch of images from a directory for training.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."
|
||||
),
|
||||
io.Combo.Input(
|
||||
"resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, folder, resize_method="None"):
|
||||
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
||||
image_files = [
|
||||
f
|
||||
for f in os.listdir(sub_input_dir)
|
||||
if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||
]
|
||||
return io.NodeOutput(load_and_process_images(image_files, sub_input_dir, resize_method))
|
||||
|
||||
|
||||
class LoadImageTextSetFromFolderNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadImageTextSetFromFolderNode_V3",
|
||||
display_name="Load Image and Text Dataset from Folder _V3",
|
||||
category="loaders",
|
||||
description="Loads a batch of images and caption from a directory for training.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Combo.Input("folder", options=folder_paths.get_input_subfolders(), tooltip="The folder to load images from."),
|
||||
io.Clip.Input("clip", tooltip="The CLIP model used for encoding the text."),
|
||||
io.Combo.Input("resize_method", options=["None", "Stretch", "Crop", "Pad"], default="None", optional=True),
|
||||
io.Int.Input("width", default=-1, min=-1, max=10000, step=1, tooltip="The width to resize the images to. -1 means use the original width.", optional=True),
|
||||
io.Int.Input("height", default=-1, min=-1, max=10000, step=1, tooltip="The height to resize the images to. -1 means use the original height.", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(),
|
||||
io.Conditioning.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, folder, clip, resize_method="None", width=None, height=None):
|
||||
if clip is None:
|
||||
raise RuntimeError(
|
||||
"ERROR: clip input is invalid: None\n\n"
|
||||
"If the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model."
|
||||
)
|
||||
|
||||
logging.info(f"Loading images from folder: {folder}")
|
||||
|
||||
sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder)
|
||||
valid_extensions = [".png", ".jpg", ".jpeg", ".webp"]
|
||||
|
||||
image_files = []
|
||||
for item in os.listdir(sub_input_dir):
|
||||
path = os.path.join(sub_input_dir, item)
|
||||
if any(item.lower().endswith(ext) for ext in valid_extensions):
|
||||
image_files.append(path)
|
||||
elif os.path.isdir(path):
|
||||
# Support kohya-ss/sd-scripts folder structure
|
||||
repeat = 1
|
||||
if item.split("_")[0].isdigit():
|
||||
repeat = int(item.split("_")[0])
|
||||
image_files.extend([
|
||||
os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions)
|
||||
] * repeat)
|
||||
|
||||
caption_file_path = [
|
||||
f.replace(os.path.splitext(f)[1], ".txt")
|
||||
for f in image_files
|
||||
]
|
||||
captions = []
|
||||
for caption_file in caption_file_path:
|
||||
caption_path = os.path.join(sub_input_dir, caption_file)
|
||||
if os.path.exists(caption_path):
|
||||
with open(caption_path, "r", encoding="utf-8") as f:
|
||||
caption = f.read().strip()
|
||||
captions.append(caption)
|
||||
else:
|
||||
captions.append("")
|
||||
|
||||
width = width if width != -1 else None
|
||||
height = height if height != -1 else None
|
||||
output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height)
|
||||
|
||||
logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.")
|
||||
|
||||
logging.info(f"Encoding captions from {sub_input_dir}.")
|
||||
conditions = []
|
||||
empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize(""))
|
||||
for text in captions:
|
||||
if text == "":
|
||||
conditions.append(empty_cond)
|
||||
tokens = clip.tokenize(text)
|
||||
conditions.extend(clip.encode_from_tokens_scheduled(tokens))
|
||||
logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.")
|
||||
return io.NodeOutput(output_tensor, conditions)
|
||||
|
||||
|
||||
class LoraModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoraModelLoader_V3",
|
||||
display_name="Load LoRA Model _V3",
|
||||
category="loaders",
|
||||
description="Load Trained LoRA weights from Train LoRA node.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The diffusion model the LoRA will be applied to."),
|
||||
io.LoraModel.Input("lora", tooltip="The LoRA model to apply to the diffusion model."),
|
||||
io.Float.Input("strength_model", default=1.0, min=-100.0, max=100.0, step=0.01, tooltip="How strongly to modify the diffusion model. This value can be negative."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The modified diffusion model."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, lora, strength_model):
|
||||
if strength_model == 0:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0)
|
||||
return io.NodeOutput(model_lora)
|
||||
|
||||
|
||||
class LossGraphNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LossGraphNode_V3",
|
||||
display_name="Plot Loss Graph _V3",
|
||||
category="training",
|
||||
description="Plots the loss graph and saves it to the output directory.",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
io.LossMap.Input("loss"), # TODO: original V1 node has also `default={}` parameter
|
||||
io.String.Input("filename_prefix", default="loss_graph"),
|
||||
],
|
||||
outputs=[],
|
||||
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, loss, filename_prefix):
|
||||
loss_values = loss["loss"]
|
||||
width, height = 800, 480
|
||||
margin = 40
|
||||
|
||||
img = Image.new(
|
||||
"RGB", (width + margin, height + margin), "white"
|
||||
) # Extend canvas
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
min_loss, max_loss = min(loss_values), max(loss_values)
|
||||
scaled_loss = [(l_v - min_loss) / (max_loss - min_loss) for l_v in loss_values]
|
||||
|
||||
steps = len(loss_values)
|
||||
|
||||
prev_point = (margin, height - int(scaled_loss[0] * height))
|
||||
for i, l_v in enumerate(scaled_loss[1:], start=1):
|
||||
x = margin + int(i / steps * width) # Scale X properly
|
||||
y = height - int(l_v * height)
|
||||
draw.line([prev_point, (x, y)], fill="blue", width=2)
|
||||
prev_point = (x, y)
|
||||
|
||||
draw.line([(margin, 0), (margin, height)], fill="black", width=2) # Y-axis
|
||||
draw.line(
|
||||
[(margin, height), (width + margin, height)], fill="black", width=2
|
||||
) # X-axis
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype("arial.ttf", 12)
|
||||
except IOError:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
# Add axis labels
|
||||
draw.text((5, height // 2), "Loss", font=font, fill="black")
|
||||
draw.text((width // 2, height + 10), "Steps", font=font, fill="black")
|
||||
|
||||
# Add min/max loss values
|
||||
draw.text((margin - 30, 0), f"{max_loss:.2f}", font=font, fill="black")
|
||||
draw.text(
|
||||
(margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black"
|
||||
)
|
||||
return io.NodeOutput(ui=ui.PreviewImage(img, cls=cls))
|
||||
|
||||
|
||||
class SaveLoRA(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SaveLoRA_V3",
|
||||
display_name="Save LoRA Weights _V3",
|
||||
category="loaders",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
io.LoraModel.Input("lora", tooltip="The LoRA model to save. Do not use the model with LoRA layers."),
|
||||
io.String.Input("prefix", default="loras/ComfyUI_trained_lora", tooltip="The prefix to use for the saved LoRA file."),
|
||||
io.Int.Input("steps", tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", optional=True),
|
||||
],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, lora, prefix, steps=None):
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
|
||||
prefix, folder_paths.get_output_directory()
|
||||
)
|
||||
if steps is None:
|
||||
output_checkpoint = f"{filename}_{counter:05}_.safetensors"
|
||||
else:
|
||||
output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
|
||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||
safetensors.torch.save_file(lora, output_checkpoint)
|
||||
return io.NodeOutput()
|
||||
|
||||
|
||||
class TrainLoraNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="TrainLoraNode_V3",
|
||||
display_name="Train LoRA _V3",
|
||||
category="training",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to train the LoRA on."),
|
||||
io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
|
||||
io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
|
||||
io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
|
||||
io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
|
||||
io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
|
||||
io.Combo.Input("optimizer", options=["AdamW", "Adam", "SGD", "RMSprop"], default="AdamW", tooltip="The optimizer to use for training."),
|
||||
io.Combo.Input("loss_function", options=["MSE", "L1", "Huber", "SmoothL1"], default="MSE", tooltip="The loss function to use for training."),
|
||||
io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
|
||||
io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
|
||||
io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
|
||||
io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name="model_with_lora"),
|
||||
io.LoraModel.Output(display_name="lora"),
|
||||
io.LossMap.Output(display_name="loss"),
|
||||
io.Int.Output(display_name="steps"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
model,
|
||||
latents,
|
||||
positive,
|
||||
batch_size,
|
||||
steps,
|
||||
learning_rate,
|
||||
rank,
|
||||
optimizer,
|
||||
loss_function,
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
existing_lora,
|
||||
):
|
||||
mp = model.clone()
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
latents = latents["samples"].to(dtype)
|
||||
num_images = latents.shape[0]
|
||||
logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}")
|
||||
if len(positive) == 1 and num_images > 1:
|
||||
positive = positive * num_images
|
||||
elif len(positive) != num_images:
|
||||
raise ValueError(
|
||||
f"Number of positive conditions ({len(positive)}) does not match number of images ({num_images})."
|
||||
)
|
||||
|
||||
with torch.inference_mode(False):
|
||||
lora_sd = {}
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(seed)
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights = {}
|
||||
existing_steps = 0
|
||||
if existing_lora != "[None]":
|
||||
lora_path = folder_paths.get_full_path_or_raise("loras", existing_lora)
|
||||
# Extract steps from filename like "trained_lora_10_steps_20250225_203716"
|
||||
existing_steps = int(existing_lora.split("_steps_")[0].split("_")[-1])
|
||||
if lora_path:
|
||||
existing_weights = comfy.utils.load_torch_file(lora_path)
|
||||
|
||||
all_weight_adapters = []
|
||||
for n, m in mp.model.named_modules():
|
||||
if hasattr(m, "weight_function"):
|
||||
if m.weight is not None:
|
||||
key = "{}.weight".format(n)
|
||||
shape = m.weight.shape
|
||||
if len(shape) >= 2:
|
||||
alpha = float(existing_weights.get(f"{key}.alpha", 1.0))
|
||||
dora_scale = existing_weights.get(
|
||||
f"{key}.dora_scale", None
|
||||
)
|
||||
for adapter_cls in adapters:
|
||||
existing_adapter = adapter_cls.load(
|
||||
n, existing_weights, alpha, dora_scale
|
||||
)
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
else:
|
||||
# If no existing adapter found, use LoRA
|
||||
# We will add algo option in the future
|
||||
existing_adapter = None
|
||||
adapter_cls = adapters[0]
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
||||
else:
|
||||
# Use LoRA with alpha=1.0 by default
|
||||
train_adapter = adapter_cls.create_train(
|
||||
m.weight, rank=rank, alpha=1.0
|
||||
).to(lora_dtype)
|
||||
for name, parameter in train_adapter.named_parameters():
|
||||
lora_sd[f"{n}.{name}"] = parameter
|
||||
|
||||
mp.add_weight_wrapper(key, train_adapter)
|
||||
all_weight_adapters.append(train_adapter)
|
||||
else:
|
||||
diff = torch.nn.Parameter(
|
||||
torch.zeros(
|
||||
m.weight.shape, dtype=lora_dtype, requires_grad=True
|
||||
)
|
||||
)
|
||||
diff_module = BiasDiff(diff)
|
||||
mp.add_weight_wrapper(key, BiasDiff(diff))
|
||||
all_weight_adapters.append(diff_module)
|
||||
lora_sd["{}.diff".format(n)] = diff
|
||||
if hasattr(m, "bias") and m.bias is not None:
|
||||
key = "{}.bias".format(n)
|
||||
bias = torch.nn.Parameter(
|
||||
torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True)
|
||||
)
|
||||
bias_module = BiasDiff(bias)
|
||||
lora_sd["{}.diff_b".format(n)] = bias
|
||||
mp.add_weight_wrapper(key, BiasDiff(bias))
|
||||
all_weight_adapters.append(bias_module)
|
||||
|
||||
if optimizer == "Adam":
|
||||
optimizer = torch.optim.Adam(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "AdamW":
|
||||
optimizer = torch.optim.AdamW(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "SGD":
|
||||
optimizer = torch.optim.SGD(lora_sd.values(), lr=learning_rate)
|
||||
elif optimizer == "RMSprop":
|
||||
optimizer = torch.optim.RMSprop(lora_sd.values(), lr=learning_rate)
|
||||
|
||||
# Setup loss function based on selection
|
||||
if loss_function == "MSE":
|
||||
criterion = torch.nn.MSELoss()
|
||||
elif loss_function == "L1":
|
||||
criterion = torch.nn.L1Loss()
|
||||
elif loss_function == "Huber":
|
||||
criterion = torch.nn.HuberLoss()
|
||||
elif loss_function == "SmoothL1":
|
||||
criterion = torch.nn.SmoothL1Loss()
|
||||
|
||||
# setup models
|
||||
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
||||
patch(m)
|
||||
mp.model.requires_grad_(False)
|
||||
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
|
||||
|
||||
# Setup sampler and guider like in test script
|
||||
loss_map = {"loss": []}
|
||||
def loss_callback(loss):
|
||||
loss_map["loss"].append(loss)
|
||||
train_sampler = TrainSampler(
|
||||
criterion,
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
total_steps=steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype
|
||||
)
|
||||
guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
|
||||
guider.set_conds(positive) # Set conditioning from input
|
||||
|
||||
# Training loop
|
||||
try:
|
||||
# Generate dummy sigmas and noise
|
||||
sigmas = torch.tensor(range(num_images))
|
||||
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
|
||||
guider.sample(
|
||||
noise.generate_noise({"samples": latents}),
|
||||
latents,
|
||||
train_sampler,
|
||||
sigmas,
|
||||
seed=noise.seed
|
||||
)
|
||||
finally:
|
||||
for m in mp.model.modules():
|
||||
unpatch(m)
|
||||
del train_sampler, optimizer
|
||||
|
||||
for adapter in all_weight_adapters:
|
||||
adapter.requires_grad_(False)
|
||||
|
||||
for param in lora_sd:
|
||||
lora_sd[param] = lora_sd[param].to(lora_dtype)
|
||||
|
||||
return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps)
|
||||
|
||||
|
||||
NODES_LIST = [
|
||||
LoadImageSetFromFolderNode,
|
||||
LoadImageTextSetFromFolderNode,
|
||||
LoraModelLoader,
|
||||
LossGraphNode,
|
||||
SaveLoRA,
|
||||
TrainLoraNode,
|
||||
]
|
4
nodes.py
4
nodes.py
@ -2350,6 +2350,10 @@ def init_builtin_extra_nodes():
|
||||
"v3/nodes_sdupscale.py",
|
||||
"v3/nodes_slg.py",
|
||||
"v3/nodes_stable_cascade.py",
|
||||
"v3/nodes_tcfg.py",
|
||||
"v3/nodes_tomesd.py",
|
||||
"v3/nodes_torch_compile.py",
|
||||
"v3/nodes_train.py",
|
||||
"v3/nodes_upscale_model.py",
|
||||
"v3/nodes_video.py",
|
||||
"v3/nodes_video_model.py",
|
||||
|
Loading…
x
Reference in New Issue
Block a user