Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-27 08:16:44 +00:00)
Better training loop implementation (#8820)
commit 1205afc708
parent 5612670ee4
@@ -23,38 +23,78 @@ from comfy.comfy_types.node_typing import IO
 from comfy.weight_adapter import adapters
 
 
+def make_batch_extra_option_dict(d, indicies, full_size=None):
+    new_dict = {}
+    for k, v in d.items():
+        newv = v
+        if isinstance(v, dict):
+            newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
+        elif isinstance(v, torch.Tensor):
+            if full_size is None or v.size(0) == full_size:
+                newv = v[indicies]
+        elif isinstance(v, (list, tuple)) and len(v) == full_size:
+            newv = [v[i] for i in indicies]
+        new_dict[k] = newv
+    return new_dict
+
+
 class TrainSampler(comfy.samplers.Sampler):
 
-    def __init__(self, loss_fn, optimizer, loss_callback=None):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
+        self.batch_size = batch_size
+        self.total_steps = total_steps
+        self.seed = seed
+        self.training_dtype = training_dtype
 
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
-        self.optimizer.zero_grad()
-        noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
-        latent = model_wrap.inner_model.model_sampling.noise_scaling(
-            torch.zeros_like(sigmas),
-            torch.zeros_like(noise, requires_grad=True),
-            latent_image,
-            False
-        )
-
-        # Ensure model is in training mode and computing gradients
-        # x0 pred
-        denoised = model_wrap(noise, sigmas, **extra_args)
-        try:
-            loss = self.loss_fn(denoised, latent.clone())
-        except RuntimeError as e:
-            if "does not require grad and does not have a grad_fn" in str(e):
-                logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
-        loss.backward()
-        if self.loss_callback:
-            self.loss_callback(loss.item())
-
-        self.optimizer.step()
-        # torch.cuda.memory._dump_snapshot("trainn.pickle")
-        # torch.cuda.memory._record_memory_history(enabled=None)
+        cond = model_wrap.conds["positive"]
+        dataset_size = sigmas.size(0)
+        torch.cuda.empty_cache()
+        for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
+            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
+            indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
+
+            batch_latent = torch.stack([latent_image[i] for i in indicies])
+            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
+            batch_sigmas = [
+                model_wrap.inner_model.model_sampling.percent_to_sigma(
+                    torch.rand((1,)).item()
+                ) for _ in range(min(self.batch_size, dataset_size))
+            ]
+            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+
+            xt = model_wrap.inner_model.model_sampling.noise_scaling(
+                batch_sigmas,
+                batch_noise,
+                batch_latent,
+                False
+            )
+            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
+                torch.zeros_like(batch_sigmas),
+                torch.zeros_like(batch_noise),
+                batch_latent,
+                False
+            )
+
+            model_wrap.conds["positive"] = [
+                cond[i] for i in indicies
+            ]
+            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
+
+            with torch.autocast(xt.device.type, dtype=self.training_dtype):
+                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
+                loss = self.loss_fn(x0_pred, x0)
+            loss.backward()
+            if self.loss_callback:
+                self.loss_callback(loss.item())
+            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+            torch.cuda.empty_cache()
         return torch.zeros_like(latent_image)
 
 
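For context (not part of the commit), a small standalone sketch of how the new make_batch_extra_option_dict helper behaves: tensor or list entries whose leading dimension matches the full dataset size are sliced down to the sampled batch indices, nested dicts are handled recursively, and everything else is passed through unchanged. The toy dictionary and values below are invented for illustration.

import torch

# Helper copied verbatim from the hunk above.
def make_batch_extra_option_dict(d, indicies, full_size=None):
    new_dict = {}
    for k, v in d.items():
        newv = v
        if isinstance(v, dict):
            newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
        elif isinstance(v, torch.Tensor):
            if full_size is None or v.size(0) == full_size:
                newv = v[indicies]
        elif isinstance(v, (list, tuple)) and len(v) == full_size:
            newv = [v[i] for i in indicies]
        new_dict[k] = newv
    return new_dict

# Toy extra_args dict (made up for the example): a scalar, a per-sample tensor,
# and a nested dict holding another per-sample tensor.
extra_args = {
    "cond_scale": 1.0,
    "per_sample": torch.arange(4),
    "nested": {"masks": torch.zeros(4, 8, 8)},
}
batch_args = make_batch_extra_option_dict(extra_args, [2, 0], full_size=4)
print(batch_args["per_sample"])             # tensor([2, 0]) -> sliced to the batch
print(batch_args["nested"]["masks"].shape)  # torch.Size([2, 8, 8]) -> sliced recursively
print(batch_args["cond_scale"])             # 1.0 -> passed through unchanged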
@@ -584,36 +624,34 @@ class TrainLoraNode:
         loss_map = {"loss": []}
         def loss_callback(loss):
             loss_map["loss"].append(loss)
-            pbar.set_postfix({"loss": f"{loss:.4f}"})
         train_sampler = TrainSampler(
-            criterion, optimizer, loss_callback=loss_callback
+            criterion,
+            optimizer,
+            loss_callback=loss_callback,
+            batch_size=batch_size,
+            total_steps=steps,
+            seed=seed,
+            training_dtype=dtype
         )
         guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
         guider.set_conds(positive) # Set conditioning from input
 
-        # yoland: this currently resize to the first image in the dataset
-
         # Training loop
-        torch.cuda.empty_cache()
         try:
-            for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
-                # Generate random sigma
-                sigmas = [mp.model.model_sampling.percent_to_sigma(
-                    torch.rand((1,)).item()
-                ) for _ in range(min(batch_size, num_images))]
-                sigmas = torch.tensor(sigmas)
-
-                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
-
-                indices = torch.randperm(num_images)[:batch_size]
-                batch_latent = latents[indices].clone()
-                guider.set_conds([positive[i] for i in indices]) # Set conditioning from input
-                guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed)
+            # Generate dummy sigmas and noise
+            sigmas = torch.tensor(range(num_images))
+            noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
+            guider.sample(
+                noise.generate_noise({"samples": latents}),
+                latents,
+                train_sampler,
+                sigmas,
+                seed=noise.seed
+            )
         finally:
             for m in mp.model.modules():
                 unpatch(m)
         del train_sampler, optimizer
-        torch.cuda.empty_cache()
 
         for adapter in all_weight_adapters:
             adapter.requires_grad_(False)
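This second hunk removes the per-step loop from TrainLoraNode itself: guider.sample is now called once, and the sigmas tensor it receives (torch.tensor(range(num_images))) only communicates the dataset size, which TrainSampler.sample reads via sigmas.size(0). Random batch selection, noise and sigma sampling, and the optimizer step all happen inside the sampler. A rough, generic sketch of that per-step pattern follows; the tiny model, dataset, and optimizer are made-up stand-ins, not ComfyUI objects or the commit's code.

import torch

torch.manual_seed(0)
dataset = torch.randn(8, 4)                 # 8 dummy "latents" of dim 4
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = torch.nn.MSELoss()

batch_size, total_steps = 2, 5
for step in range(total_steps):
    # Draw a fresh random batch each step, like torch.randperm(dataset_size)[:batch_size].
    indices = torch.randperm(dataset.size(0))[:batch_size].tolist()
    batch = torch.stack([dataset[i] for i in indices])

    # Mixed-precision forward pass, mirroring the torch.autocast block in the new sampler.
    with torch.autocast("cpu", dtype=torch.bfloat16):
        pred = model(batch)
    loss = loss_fn(pred.float(), batch)     # toy reconstruction target

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()                   # zero after stepping, matching the new loop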