mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-28 08:46:35 +00:00

apply changes from https://github.com/comfyanonymous/ComfyUI/pull/9015

This commit is contained in:
parent 2ea2bc2941
commit 66cd5152fd
@@ -17,7 +17,7 @@ import comfy.utils
 import comfy_extras.nodes_custom_sampler
 import folder_paths
 import node_helpers
-from comfy.weight_adapter import adapters
+from comfy.weight_adapter import adapter_maps, adapters
 from comfy_api.v3 import io, ui
 
 
@@ -38,12 +38,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
 
 class TrainSampler(comfy.samplers.Sampler):
 
-    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
         self.batch_size = batch_size
         self.total_steps = total_steps
+        self.grad_acc = grad_acc
         self.seed = seed
         self.training_dtype = training_dtype
 
@@ -90,8 +91,9 @@ class TrainSampler(comfy.samplers.Sampler):
                 self.loss_callback(loss.item())
             pbar.set_postfix({"loss": f"{loss.item():.4f}"})
 
-            self.optimizer.step()
-            self.optimizer.zero_grad()
+            if (i + 1) % self.grad_acc == 0:
+                self.optimizer.step()
+                self.optimizer.zero_grad()
         torch.cuda.empty_cache()
         return torch.zeros_like(latent_image)
 
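Note on the hunk above: `loss.backward()` still runs on every micro-batch, so gradients keep summing into `.grad`; the optimizer now steps and clears them only every `grad_acc` iterations, making the effective batch size `batch_size * grad_acc`. A minimal standalone sketch of the pattern (stand-in model and data, not code from this commit):

```python
import torch

# Stand-ins for the diffusion model and training data (hypothetical).
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
loss_fn = torch.nn.SmoothL1Loss()

grad_acc = 4                     # accumulate 4 micro-batches per update
total_steps = 8 * grad_acc       # 8 optimizer updates in total

for i in range(total_steps):
    x = torch.randn(2, 4)        # dummy micro-batch
    loss = loss_fn(model(x), torch.zeros(2, 4))
    loss.backward()              # gradients sum across iterations
    if (i + 1) % grad_acc == 0:  # step only every grad_acc iterations
        optimizer.step()
        optimizer.zero_grad()
```

Since the commit sums rather than averages across micro-batches, gradient magnitudes grow with `grad_acc`; dividing `loss` by `grad_acc` before `backward()` would keep them comparable to the non-accumulated case.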
@@ -461,6 +463,7 @@ class TrainLoraNode(io.ComfyNode):
                 io.Latent.Input("latents", tooltip="The Latents to use for training, serve as dataset/input of the model."),
                 io.Conditioning.Input("positive", tooltip="The positive conditioning to use for training."),
                 io.Int.Input("batch_size", default=1, min=1, max=10000, step=1, tooltip="The batch size to use for training."),
+                io.Int.Input("grad_accumulation_steps", default=1, min=1, max=1024, step=1, tooltip="The number of gradient accumulation steps to use for training."),
                 io.Int.Input("steps", default=16, min=1, max=100000, tooltip="The number of steps to train the LoRA for."),
                 io.Float.Input("learning_rate", default=0.0005, min=0.0000001, max=1.0, step=0.000001, tooltip="The learning rate to use for training."),
                 io.Int.Input("rank", default=8, min=1, max=128, tooltip="The rank of the LoRA layers."),
@@ -469,6 +472,8 @@ class TrainLoraNode(io.ComfyNode):
                 io.Int.Input("seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)"),
                 io.Combo.Input("training_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for training."),
                 io.Combo.Input("lora_dtype", options=["bf16", "fp32"], default="bf16", tooltip="The dtype to use for lora."),
+                io.Combo.Input("algorithm", options=list(adapter_maps.keys()), default=list(adapter_maps.keys())[0], tooltip="The algorithm to use for training."),
+                io.Boolean.Input("gradient_checkpointing", default=True, tooltip="Use gradient checkpointing for training."),
                 io.Combo.Input("existing_lora", options=folder_paths.get_filename_list("loras") + ["[None]"], default="[None]", tooltip="The existing LoRA to append to. Set to None for new LoRA."),
             ],
             outputs=[
@@ -487,6 +492,7 @@ class TrainLoraNode(io.ComfyNode):
         positive,
         batch_size,
         steps,
+        grad_accumulation_steps,
         learning_rate,
         rank,
         optimizer,
@@ -494,6 +500,8 @@ class TrainLoraNode(io.ComfyNode):
         seed,
         training_dtype,
         lora_dtype,
+        algorithm,
+        gradient_checkpointing,
         existing_lora,
     ):
         mp = model.clone()
@@ -544,10 +552,8 @@ class TrainLoraNode(io.ComfyNode):
                 if existing_adapter is not None:
                     break
             else:
-                # If no existing adapter found, use LoRA
-                # We will add algo option in the future
                 existing_adapter = None
-                adapter_cls = adapters[0]
+                adapter_cls = adapter_maps[algorithm]
 
             if existing_adapter is not None:
                 train_adapter = existing_adapter.to_train().to(lora_dtype)
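The hardcoded `adapters[0]` fallback becomes a lookup keyed by the new `algorithm` input. Assuming `adapter_maps` in `comfy.weight_adapter` is a plain name-to-class mapping (its definition is not part of this diff), the dispatch reduces to something like:

```python
# Hypothetical shape of adapter_maps; the real keys and classes live in
# comfy/weight_adapter and may differ.
class LoRAAdapter: ...
class LoHaAdapter: ...

adapter_maps = {
    "lora": LoRAAdapter,
    "loha": LoHaAdapter,
}

def pick_adapter_cls(algorithm: str):
    # A plain dict lookup: an unknown algorithm name raises KeyError
    # instead of silently falling back to the first registered adapter.
    return adapter_maps[algorithm]
```

Since the UI combo is populated from `list(adapter_maps.keys())`, the lookup only receives valid keys in normal operation.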
@@ -601,8 +607,9 @@ class TrainLoraNode(io.ComfyNode):
         criterion = torch.nn.SmoothL1Loss()
 
         # setup models
-        for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
-            patch(m)
+        if gradient_checkpointing:
+            for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
+                patch(m)
         mp.model.requires_grad_(False)
         comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True)
 
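Guarding the `patch(m)` loop behind `if gradient_checkpointing:` lets users opt out of activation checkpointing. As a hedged sketch of what such a patch typically does (the commit's `patch()` helper is defined elsewhere in this file and may differ), assuming the standard `torch.utils.checkpoint` recompute-on-backward mechanism:

```python
import torch
import torch.utils.checkpoint

def patch_sketch(module: torch.nn.Module) -> None:
    """Hypothetical helper: swap a module's forward for a checkpointed one,
    so activations are recomputed during backward instead of being stored,
    trading extra compute for lower VRAM."""
    original_forward = module.forward

    def checkpointed_forward(*args, **kwargs):
        return torch.utils.checkpoint.checkpoint(
            original_forward, *args, use_reentrant=False, **kwargs
        )

    module.forward = checkpointed_forward
```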
@@ -615,7 +622,8 @@ class TrainLoraNode(io.ComfyNode):
             optimizer,
             loss_callback=loss_callback,
             batch_size=batch_size,
-            total_steps=steps,
+            grad_acc=grad_accumulation_steps,
+            total_steps=steps * grad_accumulation_steps,
             seed=seed,
             training_dtype=dtype
         )
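With the scaled `total_steps`, the sampler's inner loop now counts micro-batches while the node's `steps` input keeps counting optimizer updates. A quick sanity check of the arithmetic (illustrative values only):

```python
steps = 16                        # optimizer updates requested in the node UI
grad_accumulation_steps = 4
total_steps = steps * grad_accumulation_steps   # 64 inner-loop iterations

updates = sum(
    1 for i in range(total_steps)
    if (i + 1) % grad_accumulation_steps == 0
)
assert updates == steps           # still exactly 16 optimizer.step() calls
```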