From 66cd5152fd613b4a14580c02f02e881edd0259a2 Mon Sep 17 00:00:00 2001
From: bigcat88
Date: Thu, 24 Jul 2025 15:40:39 +0300
Subject: [PATCH] apply changes from https://github.com/comfyanonymous/ComfyUI/pull/9015

---
 comfy_extras/v3/nodes_train.py | 28 ++++++++++++++++++----------
 1 file changed, 18 insertions(+), 10 deletions(-)

diff --git a/comfy_extras/v3/nodes_train.py b/comfy_extras/v3/nodes_train.py
index 1c9290bbf..46888f5be 100644
--- a/comfy_extras/v3/nodes_train.py
+++ b/comfy_extras/v3/nodes_train.py
@@ -17,7 +17,7 @@ import comfy.utils
 import comfy_extras.nodes_custom_sampler
 import folder_paths
 import node_helpers
-from comfy.weight_adapter import adapters
+from comfy.weight_adapter import adapter_maps, adapters
 from comfy_api.v3 import io, ui
 
 
@@ -38,12 +38,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
 
 
 class TrainSampler(comfy.samplers.Sampler):
-    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
         self.batch_size = batch_size
         self.total_steps = total_steps
+        self.grad_acc = grad_acc
         self.seed = seed
         self.training_dtype = training_dtype
 
@@ -90,8 +91,9 @@ class TrainSampler(comfy.samplers.Sampler):
                 self.loss_callback(loss.item())
                 pbar.set_postfix({"loss": f"{loss.item():.4f}"})
 
-            self.optimizer.step()
-            self.optimizer.zero_grad()
+            if (i + 1) % self.grad_acc == 0:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
         torch.cuda.empty_cache()
         return torch.zeros_like(latent_image)
 
@@ -461,6 +463,7 @@ class TrainLoraNode(io.ComfyNode):
                 io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
                 io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
                 io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
+                io.Int.Input("grad_accumulation_steps", default=1, min=1, max=1024, step=1, tooltip="The number of gradient accumulation steps to use for training."),
                 io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
                 io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
                 io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
@@ -469,6 +472,8 @@ class TrainLoraNode(io.ComfyNode):
                 io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
                 io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
                 io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
+                io.Combo.Input("algorithm", options=list(adapter_maps.keys()), default=list(adapter_maps.keys())[0], tooltip="The algorithm to use for training."),
+                io.Boolean.Input("gradient_checkpointing", default=True, tooltip="Use gradient checkpointing for training."),
                 io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
             ],
             outputs=[
@@ -487,6 +492,7 @@ class TrainLoraNode(io.ComfyNode):
         positive,
         batch_size,
         steps,
+        grad_accumulation_steps,
         learning_rate,
         rank,
         optimizer,
@@ -494,6 +500,8 @@ class TrainLoraNode(io.ComfyNode):
         seed,
         training_dtype,
         lora_dtype,
+        algorithm,
+        gradient_checkpointing,
         existing_lora,
     ):
         mp = model.clone()
@@ -544,10 +552,8 @@ class TrainLoraNode(io.ComfyNode):
                 if existing_adapter is not None:
                     break
             else:
-                # If no existing adapter found, use LoRA
-                # We will add algo option in the future
                 existing_adapter = None
-                adapter_cls = adapters[0]
+                adapter_cls = adapter_maps[algorithm]
 
             if existing_adapter is not None:
                 train_adapter = existing_adapter.to_train().to(lora_dtype)
@@ -601,8 +607,9 @@ class TrainLoraNode(io.ComfyNode):
         criterion = torch.nn.SmoothL1Loss()
 
         # setup models
-        for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
-            patch(m)
+        if gradient_checkpointing:
+            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
+                patch(m)
         mp.model.requires_grad_(False)
         comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
 
@@ -615,7 +622,8 @@ class TrainLoraNode(io.ComfyNode):
             optimizer,
             loss_callback=loss_callback,
             batch_size=batch_size,
-            total_steps=steps,
+            grad_acc=grad_accumulation_steps,
+            total_steps=steps * grad_accumulation_steps,
             seed=seed,
             training_dtype=dtype
         )
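
Note: the gradient-accumulation change above boils down to stepping the optimizer only every grad_acc iterations while gradients keep accumulating in between, and compensating by running steps * grad_accumulation_steps inner iterations so the number of optimizer updates stays at steps. The effective batch size becomes batch_size * grad_accumulation_steps at the cost of proportionally more forward/backward passes. A minimal standalone PyTorch sketch of that pattern (illustrative only, with a made-up model and data; not ComfyUI code):

import torch

# Hypothetical tiny model/optimizer just to exercise the update rule.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
loss_fn = torch.nn.SmoothL1Loss()

steps = 16      # desired number of optimizer updates ("steps" input)
grad_acc = 4    # "grad_accumulation_steps" input

for i in range(steps * grad_acc):      # mirrors total_steps=steps * grad_accumulation_steps
    x = torch.randn(2, 4)
    loss = loss_fn(model(x), torch.zeros_like(x))
    loss.backward()                    # gradients accumulate across iterations
    if (i + 1) % grad_acc == 0:        # same condition as in TrainSampler.sample
        optimizer.step()
        optimizer.zero_grad()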