mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-07-27 08:16:44 +00:00
[Training Node] algo support, grad acc, optional grad ckpt (#9015)
* Add factorization utils for lokr * Add lokr train impl * Add loha train impl * Add adapter map for algo selection * Add optional grad ckpt and algo selection * Update __init__.py * correct key name for loha * Use custom fwd/bwd func and better init for loha * Support gradient accumulation * Fix bugs of loha * use more stable init * Add OFT training * linting
This commit is contained in:
parent
e729a5cc11
commit
eb2f78b4e0
@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
|
||||
OFTAdapter,
|
||||
BOFTAdapter,
|
||||
]
|
||||
adapter_maps: dict[str, type[WeightAdapterBase]] = {
|
||||
"LoRA": LoRAAdapter,
|
||||
"LoHa": LoHaAdapter,
|
||||
"LoKr": LoKrAdapter,
|
||||
"OFT": OFTAdapter,
|
||||
## We disable not implemented algo for now
|
||||
# "GLoRA": GLoRAAdapter,
|
||||
# "BOFT": BOFTAdapter,
|
||||
}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"WeightAdapterBase",
|
||||
"WeightAdapterTrainBase",
|
||||
"adapters"
|
||||
"adapters",
|
||||
"adapter_maps",
|
||||
] + [a.__name__ for a in adapters]
|
||||
|
@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
|
||||
def tucker_weight(wa, wb, t):
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
|
||||
return torch.einsum("i j ..., i r -> r j ...", temp, wa)
|
||||
|
||||
|
||||
def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
|
||||
"""
|
||||
return a tuple of two value of input dimension decomposed by the number closest to factor
|
||||
second value is higher or equal than first value.
|
||||
|
||||
examples)
|
||||
factor
|
||||
-1 2 4 8 16 ...
|
||||
127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127
|
||||
128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16
|
||||
250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25
|
||||
360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30
|
||||
512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32
|
||||
1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64
|
||||
"""
|
||||
|
||||
if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
|
||||
m = factor
|
||||
n = dimension // factor
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
if factor < 0:
|
||||
factor = dimension
|
||||
m, n = 1, dimension
|
||||
length = m + n
|
||||
while m < n:
|
||||
new_m = m + 1
|
||||
while dimension % new_m != 0:
|
||||
new_m += 1
|
||||
new_n = dimension // new_m
|
||||
if new_m + new_n > length or new_m > factor:
|
||||
break
|
||||
else:
|
||||
m, n = new_m, new_n
|
||||
if m > n:
|
||||
n, m = m, n
|
||||
return m, n
|
||||
|
@ -3,7 +3,120 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
|
||||
|
||||
|
||||
class HadaWeight(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
|
||||
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
|
||||
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
|
||||
return diff_weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
temp = grad_out * (w2u @ w2d)
|
||||
grad_w1u = temp @ w1d.T
|
||||
grad_w1d = w1u.T @ temp
|
||||
|
||||
temp = grad_out * (w1u @ w1d)
|
||||
grad_w2u = temp @ w2d.T
|
||||
grad_w2d = w2u.T @ temp
|
||||
|
||||
del temp
|
||||
return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
|
||||
|
||||
|
||||
class HadaWeightTucker(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
|
||||
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
|
||||
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
|
||||
|
||||
return rebuild1 * rebuild2 * scale
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
|
||||
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
|
||||
del grad_temp
|
||||
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
|
||||
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
|
||||
del grad_temp
|
||||
return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
|
||||
|
||||
|
||||
class LohaDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
self.hada_w1_a = torch.nn.Parameter(w1a)
|
||||
self.hada_w1_b = torch.nn.Parameter(w1b)
|
||||
self.hada_w2_a = torch.nn.Parameter(w2a)
|
||||
self.hada_w2_b = torch.nn.Parameter(w2b)
|
||||
|
||||
self.use_tucker = False
|
||||
if t1 is not None and t2 is not None:
|
||||
self.use_tucker = True
|
||||
self.hada_t1 = torch.nn.Parameter(t1)
|
||||
self.hada_t2 = torch.nn.Parameter(t2)
|
||||
else:
|
||||
# Keep the attributes for consistent access
|
||||
self.hada_t1 = None
|
||||
self.hada_t2 = None
|
||||
|
||||
# Store rank and non-trainable alpha
|
||||
self.rank = w1b.shape[0]
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
def __call__(self, w):
|
||||
org_dtype = w.dtype
|
||||
|
||||
scale = self.alpha / self.rank
|
||||
if self.use_tucker:
|
||||
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
else:
|
||||
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
|
||||
# Add the scaled difference to the original weight
|
||||
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
|
||||
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
|
||||
class LoHaAdapter(WeightAdapterBase):
|
||||
@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.normal_(mat1, 0.1)
|
||||
torch.nn.init.constant_(mat2, 0.0)
|
||||
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.normal_(mat3, 0.1)
|
||||
torch.nn.init.normal_(mat4, 0.01)
|
||||
return LohaDiff(
|
||||
(mat1, mat2, alpha, mat3, mat4, None, None, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return LohaDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
@ -3,7 +3,77 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
from .base import (
|
||||
WeightAdapterBase,
|
||||
WeightAdapterTrainBase,
|
||||
weight_decompose,
|
||||
factorization,
|
||||
)
|
||||
|
||||
|
||||
class LokrDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
(lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
|
||||
self.use_tucker = False
|
||||
if lokr_w1_a is not None:
|
||||
_, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
|
||||
rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
|
||||
self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
|
||||
self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
|
||||
self.w1_rebuild = True
|
||||
self.ranka = rank_a
|
||||
|
||||
if lokr_w2_a is not None:
|
||||
_, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
|
||||
rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
|
||||
self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
|
||||
self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
|
||||
if lokr_t2 is not None:
|
||||
self.use_tucker = True
|
||||
self.lokr_t2 = torch.nn.Parameter(lokr_t2)
|
||||
self.w2_rebuild = True
|
||||
self.rankb = rank_b
|
||||
|
||||
if lokr_w1 is not None:
|
||||
self.lokr_w1 = torch.nn.Parameter(lokr_w1)
|
||||
self.w1_rebuild = False
|
||||
|
||||
if lokr_w2 is not None:
|
||||
self.lokr_w2 = torch.nn.Parameter(lokr_w2)
|
||||
self.w2_rebuild = False
|
||||
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
@property
|
||||
def w1(self):
|
||||
if self.w1_rebuild:
|
||||
return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
|
||||
else:
|
||||
return self.lokr_w1
|
||||
|
||||
@property
|
||||
def w2(self):
|
||||
if self.w2_rebuild:
|
||||
if self.use_tucker:
|
||||
w2 = torch.einsum(
|
||||
'i j k l, j r, i p -> p r k l',
|
||||
self.lokr_t2,
|
||||
self.lokr_w2_b,
|
||||
self.lokr_w2_a
|
||||
)
|
||||
else:
|
||||
w2 = self.lokr_w2_a @ self.lokr_w2_b
|
||||
return w2 * (self.alpha / self.rankb)
|
||||
else:
|
||||
return self.lokr_w2
|
||||
|
||||
def __call__(self, w):
|
||||
diff = torch.kron(self.w1, self.w2)
|
||||
return w + diff.reshape(w.shape).to(w)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
|
||||
class LoKrAdapter(WeightAdapterBase):
|
||||
@ -13,6 +83,20 @@ class LoKrAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
out1, out2 = factorization(out_dim, rank)
|
||||
in1, in2 = factorization(in_dim, rank)
|
||||
mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
|
||||
torch.nn.init.constant_(mat1, 0.0)
|
||||
return LokrDiff(
|
||||
(mat1, mat2, alpha, None, None, None, None, None, None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
@ -3,7 +3,58 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
|
||||
|
||||
|
||||
class OFTDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
blocks, rescale, alpha, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
self.oft_blocks = torch.nn.Parameter(blocks)
|
||||
if rescale is not None:
|
||||
self.rescale = torch.nn.Parameter(rescale)
|
||||
self.rescaled = True
|
||||
else:
|
||||
self.rescaled = False
|
||||
self.block_num, self.block_size, _ = blocks.shape
|
||||
self.constraint = float(alpha)
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
def __call__(self, w):
|
||||
org_dtype = w.dtype
|
||||
I = torch.eye(self.block_size, device=self.oft_blocks.device)
|
||||
|
||||
## generate r
|
||||
# for Q = -Q^T
|
||||
q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
|
||||
normed_q = q
|
||||
if self.constraint:
|
||||
q_norm = torch.norm(q) + 1e-8
|
||||
if q_norm > self.constraint:
|
||||
normed_q = q * self.constraint / q_norm
|
||||
# use float() to prevent unsupported type
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
|
||||
## Apply chunked matmul on weight
|
||||
_, *shape = w.shape
|
||||
org_weight = w.to(dtype=r.dtype)
|
||||
org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
|
||||
# Init R=0, so add I on it to ensure the output of step0 is original model output
|
||||
weight = torch.einsum(
|
||||
"k n m, k n ... -> k m ...",
|
||||
r,
|
||||
org_weight,
|
||||
).flatten(0, 1)
|
||||
if self.rescaled:
|
||||
weight = self.rescale * weight
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
|
||||
class OFTAdapter(WeightAdapterBase):
|
||||
@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
block_size, block_num = factorization(out_dim, rank)
|
||||
block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
|
||||
return OFTDiff(
|
||||
(block, None, alpha, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return OFTDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
@ -60,6 +123,8 @@ class OFTAdapter(WeightAdapterBase):
|
||||
blocks = v[0]
|
||||
rescale = v[1]
|
||||
alpha = v[2]
|
||||
if alpha is None:
|
||||
alpha = 0
|
||||
dora_scale = v[3]
|
||||
|
||||
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
|
||||
|
@ -20,7 +20,7 @@ import folder_paths
|
||||
import node_helpers
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
from comfy.weight_adapter import adapters
|
||||
from comfy.weight_adapter import adapters, adapter_maps
|
||||
|
||||
|
||||
def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||
@ -39,13 +39,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
|
||||
|
||||
|
||||
class TrainSampler(comfy.samplers.Sampler):
|
||||
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
|
||||
def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
self.loss_callback = loss_callback
|
||||
self.batch_size = batch_size
|
||||
self.total_steps = total_steps
|
||||
self.grad_acc = grad_acc
|
||||
self.seed = seed
|
||||
self.training_dtype = training_dtype
|
||||
|
||||
@ -92,6 +92,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.loss_callback(loss.item())
|
||||
pbar.set_postfix({"loss": f"{loss.item():.4f}"})
|
||||
|
||||
if (i+1) % self.grad_acc == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
torch.cuda.empty_cache()
|
||||
@ -419,6 +420,16 @@ class TrainLoraNode:
|
||||
"tooltip": "The batch size to use for training.",
|
||||
},
|
||||
),
|
||||
"grad_accumulation_steps": (
|
||||
IO.INT,
|
||||
{
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 1024,
|
||||
"step": 1,
|
||||
"tooltip": "The number of gradient accumulation steps to use for training.",
|
||||
}
|
||||
),
|
||||
"steps": (
|
||||
IO.INT,
|
||||
{
|
||||
@ -478,6 +489,17 @@ class TrainLoraNode:
|
||||
["bf16", "fp32"],
|
||||
{"default": "bf16", "tooltip": "The dtype to use for lora."},
|
||||
),
|
||||
"algorithm": (
|
||||
list(adapter_maps.keys()),
|
||||
{"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
|
||||
),
|
||||
"gradient_checkpointing": (
|
||||
IO.BOOLEAN,
|
||||
{
|
||||
"default": True,
|
||||
"tooltip": "Use gradient checkpointing for training.",
|
||||
}
|
||||
),
|
||||
"existing_lora": (
|
||||
folder_paths.get_filename_list("loras") + ["[None]"],
|
||||
{
|
||||
@ -501,6 +523,7 @@ class TrainLoraNode:
|
||||
positive,
|
||||
batch_size,
|
||||
steps,
|
||||
grad_accumulation_steps,
|
||||
learning_rate,
|
||||
rank,
|
||||
optimizer,
|
||||
@ -508,6 +531,8 @@ class TrainLoraNode:
|
||||
seed,
|
||||
training_dtype,
|
||||
lora_dtype,
|
||||
algorithm,
|
||||
gradient_checkpointing,
|
||||
existing_lora,
|
||||
):
|
||||
mp = model.clone()
|
||||
@ -558,10 +583,8 @@ class TrainLoraNode:
|
||||
if existing_adapter is not None:
|
||||
break
|
||||
else:
|
||||
# If no existing adapter found, use LoRA
|
||||
# We will add algo option in the future
|
||||
existing_adapter = None
|
||||
adapter_cls = adapters[0]
|
||||
adapter_cls = adapter_maps[algorithm]
|
||||
|
||||
if existing_adapter is not None:
|
||||
train_adapter = existing_adapter.to_train().to(lora_dtype)
|
||||
@ -615,6 +638,7 @@ class TrainLoraNode:
|
||||
criterion = torch.nn.SmoothL1Loss()
|
||||
|
||||
# setup models
|
||||
if gradient_checkpointing:
|
||||
for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
|
||||
patch(m)
|
||||
mp.model.requires_grad_(False)
|
||||
@ -629,7 +653,8 @@ class TrainLoraNode:
|
||||
optimizer,
|
||||
loss_callback=loss_callback,
|
||||
batch_size=batch_size,
|
||||
total_steps=steps,
|
||||
grad_acc=grad_accumulation_steps,
|
||||
total_steps=steps*grad_accumulation_steps,
|
||||
seed=seed,
|
||||
training_dtype=dtype
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user