Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-27 08:16:44 +00:00)
apply changes from https://github.com/comfyanonymous/ComfyUI/pull/9015
This commit is contained in: parent 2ea2bc2941, commit 66cd5152fd
@@ -17,7 +17,7 @@ import comfy.utils
 import comfy_extras.nodes_custom_sampler
 import folder_paths
 import node_helpers
-from comfy.weight_adapter import adapters
+from comfy.weight_adapter import adapter_maps, adapters
 from comfy_api.v3 import io, ui
@@ -38,12 +38,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):

 class TrainSampler(comfy.samplers.Sampler):

-    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
         self.batch_size = batch_size
         self.total_steps = total_steps
+        self.grad_acc = grad_acc
         self.seed = seed
         self.training_dtype = training_dtype
@@ -90,8 +91,9 @@ class TrainSampler(comfy.samplers.Sampler):
                 self.loss_callback(loss.item())
                 pbar.set_postfix({"loss": f"{loss.item():.4f}"})

-            self.optimizer.step()
-            self.optimizer.zero_grad()
+            if (i + 1) % self.grad_acc == 0:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
         torch.cuda.empty_cache()
         return torch.zeros_like(latent_image)
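The change above defers the optimizer update so gradients from several iterations accumulate before a single step. A minimal, self-contained sketch of that pattern, using a toy linear model in place of the diffusion model (names and values here are illustrative, not from the commit):

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_acc = 4        # plays the role of the new grad_acc parameter
total_steps = 8     # inner iterations; the node passes steps * grad_accumulation_steps

for i in range(total_steps):
    x = torch.randn(2, 4)
    loss = torch.nn.functional.mse_loss(model(x), torch.zeros(2, 4))
    loss.backward()                  # gradients keep accumulating in .grad
    if (i + 1) % grad_acc == 0:      # same condition as in TrainSampler.sample
        optimizer.step()             # one parameter update per grad_acc iterations
        optimizer.zero_grad()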
@@ -461,6 +463,7 @@ class TrainLoraNode(io.ComfyNode):
                 io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
                 io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
                 io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
+                io.Int.Input("grad_accumulation_steps", default=1, min=1, max=1024, step=1, tooltip="The number of gradient accumulation steps to use for training."),
                 io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
                 io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
                 io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
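For context (not part of this commit): accumulating over grad_accumulation_steps iterations approximates training with a larger batch when memory rules out raising batch_size directly; how closely it matches depends on how the loss is normalized. Roughly, with illustrative values:

batch_size = 2
grad_accumulation_steps = 4
effective_batch = batch_size * grad_accumulation_steps  # ~8, assuming comparable per-iteration losses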
@@ -469,6 +472,8 @@ class TrainLoraNode(io.ComfyNode):
                 io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
                 io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
                 io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
+                io.Combo.Input("algorithm", options=list(adapter_maps.keys()), default=list(adapter_maps.keys())[0], tooltip="The algorithm to use for training."),
+                io.Boolean.Input("gradient_checkpointing", default=True, tooltip="Use gradient checkpointing for training."),
                 io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
             ],
             outputs=[
@@ -487,6 +492,7 @@ class TrainLoraNode(io.ComfyNode):
         positive,
         batch_size,
         steps,
+        grad_accumulation_steps,
         learning_rate,
         rank,
         optimizer,
@@ -494,6 +500,8 @@ class TrainLoraNode(io.ComfyNode):
         seed,
         training_dtype,
         lora_dtype,
+        algorithm,
+        gradient_checkpointing,
         existing_lora,
     ):
         mp = model.clone()
@@ -544,10 +552,8 @@ class TrainLoraNode(io.ComfyNode):
             if existing_adapter is not None:
                 break
         else:
-            # If no existing adapter found, use LoRA
-            # We will add algo option in the future
            existing_adapter = None
-            adapter_cls = adapters[0]
+            adapter_cls = adapter_maps[algorithm]

         if existing_adapter is not None:
             train_adapter = existing_adapter.to_train().to(lora_dtype)
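Here the hard-coded adapters[0] (always LoRA) is replaced by a lookup keyed on the new algorithm input. A sketch of that selection pattern, assuming adapter_maps is a plain name-to-class mapping; the class names below are hypothetical stand-ins, the real ones live in comfy.weight_adapter:

class LoRAAdapter: ...   # hypothetical stand-ins for the real adapter classes
class LoHaAdapter: ...

adapter_maps = {"LoRA": LoRAAdapter, "LoHa": LoHaAdapter}

algorithm = "LoRA"                     # value chosen in the new "algorithm" Combo input
adapter_cls = adapter_maps[algorithm]  # replaces the previous adapters[0]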
@@ -601,8 +607,9 @@ class TrainLoraNode(io.ComfyNode):
         criterion = torch.nn.SmoothL1Loss()

         # setup models
-        for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
-            patch(m)
+        if gradient_checkpointing:
+            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
+                patch(m)
         mp.model.requires_grad_(False)
         comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
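Guarding the patching loop behind the new flag makes gradient checkpointing optional. The patch() helper itself is outside this diff; as an illustration only, a forward wrapper built on torch.utils.checkpoint could look roughly like this (an assumption about the mechanism, not the commit's code):

import torch
import torch.utils.checkpoint

def checkpoint_forward(module):
    original_forward = module.forward
    def wrapped(*args, **kwargs):
        # Recompute activations during backward instead of storing them,
        # trading extra compute for lower peak memory.
        return torch.utils.checkpoint.checkpoint(original_forward, *args, use_reentrant=False, **kwargs)
    module.forward = wrapped

block = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())
checkpoint_forward(block)
out = block(torch.randn(2, 8, requires_grad=True))
out.sum().backward()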
@@ -615,7 +622,8 @@ class TrainLoraNode(io.ComfyNode):
             optimizer,
             loss_callback=loss_callback,
             batch_size=batch_size,
-            total_steps=steps,
+            grad_acc=grad_accumulation_steps,
+            total_steps=steps * grad_accumulation_steps,
             seed=seed,
             training_dtype=dtype
         )
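Because the sampler now only steps the optimizer every grad_acc iterations, total_steps is scaled by grad_accumulation_steps so the number of optimizer updates still matches the node's steps input. A quick worked example with illustrative values:

steps = 16                                                   # updates requested on the node
grad_accumulation_steps = 4
total_steps = steps * grad_accumulation_steps                # 64 inner training iterations
optimizer_updates = total_steps // grad_accumulation_steps   # back to 16 optimizer.step() calls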