This commit is contained in:
bigcat88 2025-07-24 15:40:39 +03:00
parent 2ea2bc2941
commit 66cd5152fd
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721

View File

@ -17,7 +17,7 @@ import comfy.utils
import comfy_extras.nodes_custom_sampler
import folder_paths
import node_helpers
from comfy.weight_adapter import adapters
from comfy.weight_adapter import adapter_maps, adapters
from comfy_api.v3 import io, ui
@ -38,12 +38,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
class TrainSampler(comfy.samplers.Sampler):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
self.loss_fn = loss_fn
self.optimizer = optimizer
self.loss_callback = loss_callback
self.batch_size = batch_size
self.total_steps = total_steps
self.grad_acc = grad_acc
self.seed = seed
self.training_dtype = training_dtype
@ -90,8 +91,9 @@ class TrainSampler(comfy.samplers.Sampler):
self.loss_callback(loss.item())
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
self.optimizer.step()
self.optimizer.zero_grad()
if (i + 1) % self.grad_acc == 0:
self.optimizer.step()
self.optimizer.zero_grad()
torch.cuda.empty_cache()
return torch.zeros_like(latent_image)
@ -461,6 +463,7 @@ class TrainLoraNode(io.ComfyNode):
io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
io.Int.Input("grad_accumulation_steps", default=1, min=1, max=1024, step=1, tooltip="The number of gradient accumulation steps to use for training."),
io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
@ -469,6 +472,8 @@ class TrainLoraNode(io.ComfyNode):
io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
io.Combo.Input("algorithm", options=list(adapter_maps.keys()), default=list(adapter_maps.keys())[0], tooltip="The algorithm to use for training."),
io.Boolean.Input("gradient_checkpointing", default=True, tooltip="Use gradient checkpointing for training."),
io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
],
outputs=[
@ -487,6 +492,7 @@ class TrainLoraNode(io.ComfyNode):
positive,
batch_size,
steps,
grad_accumulation_steps,
learning_rate,
rank,
optimizer,
@ -494,6 +500,8 @@ class TrainLoraNode(io.ComfyNode):
seed,
training_dtype,
lora_dtype,
algorithm,
gradient_checkpointing,
existing_lora,
):
mp = model.clone()
@ -544,10 +552,8 @@ class TrainLoraNode(io.ComfyNode):
if existing_adapter is not None:
break
else:
# If no existing adapter found, use LoRA
# We will add algo option in the future
existing_adapter = None
adapter_cls = adapters[0]
adapter_cls = adapter_maps[algorithm]
if existing_adapter is not None:
train_adapter = existing_adapter.to_train().to(lora_dtype)
@ -601,8 +607,9 @@ class TrainLoraNode(io.ComfyNode):
criterion = torch.nn.SmoothL1Loss()
# setup models
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
if gradient_checkpointing:
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
patch(m)
mp.model.requires_grad_(False)
comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
@ -615,7 +622,8 @@ class TrainLoraNode(io.ComfyNode):
optimizer,
loss_callback=loss_callback,
batch_size=batch_size,
total_steps=steps,
grad_acc=grad_accumulation_steps,
total_steps=steps * grad_accumulation_steps,
seed=seed,
training_dtype=dtype
)