Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2025-07-27 08:16:44 +00:00)
[Training Node] algo support, grad acc, optional grad ckpt (#9015)
* Add factorization utils for lokr
* Add lokr train impl
* Add loha train impl
* Add adapter map for algo selection
* Add optional grad ckpt and algo selection
* Update __init__.py
* correct key name for loha
* Use custom fwd/bwd func and better init for loha
* Support gradient accumulation
* Fix bugs of loha
* use more stable init
* Add OFT training
* linting
Parent: e729a5cc11
Commit: eb2f78b4e0
@@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
     OFTAdapter,
     BOFTAdapter,
 ]
+adapter_maps: dict[str, type[WeightAdapterBase]] = {
+    "LoRA": LoRAAdapter,
+    "LoHa": LoHaAdapter,
+    "LoKr": LoKrAdapter,
+    "OFT": OFTAdapter,
+    ## We disable not implemented algo for now
+    # "GLoRA": GLoRAAdapter,
+    # "BOFT": BOFTAdapter,
+}
+
 
 __all__ = [
     "WeightAdapterBase",
     "WeightAdapterTrainBase",
-    "adapters"
+    "adapters",
+    "adapter_maps",
 ] + [a.__name__ for a in adapters]
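The new adapter_maps dict is a small registry keyed by the algorithm names the training node exposes. A minimal sketch of how a caller could resolve a class from it and build a fresh trainable adapter; the algorithm string, weight shape, rank and alpha below are illustrative values, not taken from this commit:

    import torch
    from comfy.weight_adapter import adapter_maps

    algorithm = "LoKr"                     # e.g. the string picked in the node's dropdown
    adapter_cls = adapter_maps[algorithm]  # -> LoKrAdapter

    # Every class registered in the map exposes create_train() after this commit,
    # so a trainable delta can be built for any base weight.
    weight = torch.zeros(64, 32)
    train_adapter = adapter_cls.create_train(weight, rank=4, alpha=1.0)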
@@ -133,3 +133,43 @@ def tucker_weight_from_conv(up, down, mid):
 def tucker_weight(wa, wb, t):
     temp = torch.einsum("i j ..., j r -> i r ...", t, wb)
     return torch.einsum("i j ..., i r -> r j ...", temp, wa)
+
+
+def factorization(dimension: int, factor: int = -1) -> tuple[int, int]:
+    """
+    return a tuple of two value of input dimension decomposed by the number closest to factor
+    second value is higher or equal than first value.
+
+    examples)
+    factor
+        -1               2                4               8               16
+    127 -> 1, 127   127 -> 1, 127    127 -> 1, 127   127 -> 1, 127   127 -> 1, 127
+    128 -> 8, 16    128 -> 2, 64     128 -> 4, 32    128 -> 8, 16    128 -> 8, 16
+    250 -> 10, 25   250 -> 2, 125    250 -> 2, 125   250 -> 5, 50    250 -> 10, 25
+    360 -> 8, 45    360 -> 2, 180    360 -> 4, 90    360 -> 8, 45    360 -> 12, 30
+    512 -> 16, 32   512 -> 2, 256    512 -> 4, 128   512 -> 8, 64    512 -> 16, 32
+    1024 -> 32, 32  1024 -> 2, 512   1024 -> 4, 256  1024 -> 8, 128  1024 -> 16, 64
+    """
+
+    if factor > 0 and (dimension % factor) == 0 and dimension >= factor**2:
+        m = factor
+        n = dimension // factor
+        if m > n:
+            n, m = m, n
+        return m, n
+    if factor < 0:
+        factor = dimension
+    m, n = 1, dimension
+    length = m + n
+    while m < n:
+        new_m = m + 1
+        while dimension % new_m != 0:
+            new_m += 1
+        new_n = dimension // new_m
+        if new_m + new_n > length or new_m > factor:
+            break
+        else:
+            m, n = new_m, new_n
+    if m > n:
+        n, m = m, n
+    return m, n
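A quick sanity check of the two branches of factorization, assuming the helper is importable from the weight_adapter base module shown in the hunk context (concrete numbers are illustrative):

    from comfy.weight_adapter.base import factorization  # assumed import path

    # factor > 0: the requested factor is used directly when it divides the
    # dimension and factor**2 <= dimension.
    assert factorization(128, 8) == (8, 16)    # 8 * 16 == 128
    assert factorization(512, 4) == (4, 128)   # 4 * 128 == 512

    # factor < 0: walk m upward through divisors to the most balanced pair.
    assert factorization(1024, -1) == (32, 32)
    assert factorization(127, -1) == (1, 127)  # a prime can only split as 1 x n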
@@ -3,7 +3,120 @@ from typing import Optional
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
+
+
+class HadaWeight(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
+        ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
+        diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
+        return diff_weight
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        (w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
+        grad_out = grad_out * scale
+        temp = grad_out * (w2u @ w2d)
+        grad_w1u = temp @ w1d.T
+        grad_w1d = w1u.T @ temp
+
+        temp = grad_out * (w1u @ w1d)
+        grad_w2u = temp @ w2d.T
+        grad_w2d = w2u.T @ temp
+
+        del temp
+        return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
+
+
+class HadaWeightTucker(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
+        ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
+
+        rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
+        rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
+
+        return rebuild1 * rebuild2 * scale
+
+    @staticmethod
+    def backward(ctx, grad_out):
+        (t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
+        grad_out = grad_out * scale
+
+        temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
+        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
+
+        grad_w = rebuild * grad_out
+        del rebuild
+
+        grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
+        del grad_w, temp
+
+        grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
+        grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
+        del grad_temp
+
+        temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
+        rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
+
+        grad_w = rebuild * grad_out
+        del rebuild
+
+        grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
+        grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
+        del grad_w, temp
+
+        grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
+        grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
+        del grad_temp
+        return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
+
+
+class LohaDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        # Unpack weights tuple from LoHaAdapter
+        w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
+
+        # Create trainable parameters
+        self.hada_w1_a = torch.nn.Parameter(w1a)
+        self.hada_w1_b = torch.nn.Parameter(w1b)
+        self.hada_w2_a = torch.nn.Parameter(w2a)
+        self.hada_w2_b = torch.nn.Parameter(w2b)
+
+        self.use_tucker = False
+        if t1 is not None and t2 is not None:
+            self.use_tucker = True
+            self.hada_t1 = torch.nn.Parameter(t1)
+            self.hada_t2 = torch.nn.Parameter(t2)
+        else:
+            # Keep the attributes for consistent access
+            self.hada_t1 = None
+            self.hada_t2 = None
+
+        # Store rank and non-trainable alpha
+        self.rank = w1b.shape[0]
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+
+        scale = self.alpha / self.rank
+        if self.use_tucker:
+            diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
+        else:
+            diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
+
+        # Add the scaled difference to the original weight
+        weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
+
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        """Calculates memory usage of the trainable parameters."""
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
+
 class LoHaAdapter(WeightAdapterBase):
@@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat1, 0.1)
+        torch.nn.init.constant_(mat2, 0.0)
+        mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
+        mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.normal_(mat3, 0.1)
+        torch.nn.init.normal_(mat4, 0.01)
+        return LohaDiff(
+            (mat1, mat2, alpha, mat3, mat4, None, None, None)
+        )
+
+    def to_train(self):
+        return LohaDiff(self.weights)
+
     @classmethod
     def load(
         cls,
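HadaWeight and HadaWeightTucker recompute the low-rank products inside backward instead of letting autograd keep the full-size intermediates alive, trading a little extra compute for activation memory. A small check that the hand-written backward matches plain autograd, under assumed shapes and an assumed import path for the module above:

    import torch
    from comfy.weight_adapter.loha import HadaWeight  # assumed module path

    torch.manual_seed(0)
    w1a = torch.randn(6, 2, dtype=torch.double, requires_grad=True)
    w1b = torch.randn(2, 4, dtype=torch.double, requires_grad=True)
    w2a = torch.randn(6, 2, dtype=torch.double, requires_grad=True)
    w2b = torch.randn(2, 4, dtype=torch.double, requires_grad=True)
    scale = torch.tensor(0.5, dtype=torch.double)

    # Reference: let autograd differentiate the plain Hadamard-product expression.
    ref = ((w1a @ w1b) * (w2a @ w2b)) * scale
    ref.sum().backward()
    ref_grads = [p.grad.clone() for p in (w1a, w1b, w2a, w2b)]
    for p in (w1a, w1b, w2a, w2b):
        p.grad = None

    # Custom autograd function: same value and same gradients.
    out = HadaWeight.apply(w1a, w1b, w2a, w2b, scale)
    out.sum().backward()
    assert torch.allclose(out, ref)
    assert all(torch.allclose(p.grad, g) for p, g in zip((w1a, w1b, w2a, w2b), ref_grads))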
@@ -3,7 +3,77 @@ from typing import Optional
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import (
+    WeightAdapterBase,
+    WeightAdapterTrainBase,
+    weight_decompose,
+    factorization,
+)
+
+
+class LokrDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
+        self.use_tucker = False
+        if lokr_w1_a is not None:
+            _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1]
+            rank_a, _ = lokr_w1_b.shape[0], lokr_w1_b.shape[1]
+            self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
+            self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
+            self.w1_rebuild = True
+            self.ranka = rank_a
+
+        if lokr_w2_a is not None:
+            _, rank_b = lokr_w2_a.shape[0], lokr_w2_a.shape[1]
+            rank_b, _ = lokr_w2_b.shape[0], lokr_w2_b.shape[1]
+            self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
+            self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
+            if lokr_t2 is not None:
+                self.use_tucker = True
+                self.lokr_t2 = torch.nn.Parameter(lokr_t2)
+            self.w2_rebuild = True
+            self.rankb = rank_b
+
+        if lokr_w1 is not None:
+            self.lokr_w1 = torch.nn.Parameter(lokr_w1)
+            self.w1_rebuild = False
+
+        if lokr_w2 is not None:
+            self.lokr_w2 = torch.nn.Parameter(lokr_w2)
+            self.w2_rebuild = False
+
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    @property
+    def w1(self):
+        if self.w1_rebuild:
+            return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
+        else:
+            return self.lokr_w1
+
+    @property
+    def w2(self):
+        if self.w2_rebuild:
+            if self.use_tucker:
+                w2 = torch.einsum(
+                    'i j k l, j r, i p -> p r k l',
+                    self.lokr_t2,
+                    self.lokr_w2_b,
+                    self.lokr_w2_a
+                )
+            else:
+                w2 = self.lokr_w2_a @ self.lokr_w2_b
+            return w2 * (self.alpha / self.rankb)
+        else:
+            return self.lokr_w2
+
+    def __call__(self, w):
+        diff = torch.kron(self.w1, self.w2)
+        return w + diff.reshape(w.shape).to(w)
+
+    def passive_memory_usage(self):
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
+
 class LoKrAdapter(WeightAdapterBase):
@@ -13,6 +83,20 @@ class LoKrAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        in_dim = weight.shape[1:].numel()
+        out1, out2 = factorization(out_dim, rank)
+        in1, in2 = factorization(in_dim, rank)
+        mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
+        mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
+        torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
+        torch.nn.init.constant_(mat1, 0.0)
+        return LokrDiff(
+            (mat1, mat2, alpha, None, None, None, None, None, None)
+        )
+
     @classmethod
     def load(
         cls,
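LokrDiff rebuilds the weight delta as a Kronecker product of two small factors sized by factorization(), so torch.kron recovers the full (out_dim, in_dim) matrix from far fewer parameters. A shape-level sketch with assumed dimensions (not code from the commit):

    import torch
    from comfy.weight_adapter.base import factorization  # assumed import path

    out_dim, in_dim, rank = 128, 64, 8
    out1, out2 = factorization(out_dim, rank)   # (8, 16)
    in1, in2 = factorization(in_dim, rank)      # (8, 8)

    w1 = torch.zeros(out1, in1)   # mat1 in create_train: zero-initialised
    w2 = torch.randn(out2, in2)   # mat2 in create_train: kaiming-initialised

    diff = torch.kron(w1, w2)     # (out1 * out2, in1 * in2) == (128, 64)
    assert diff.shape == (out_dim, in_dim)

    # Since w1 starts at zero, the initial delta is zero and step 0 reproduces
    # the unmodified base weight.
    assert torch.count_nonzero(diff) == 0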
@@ -3,7 +3,58 @@ from typing import Optional
 
 import torch
 import comfy.model_management
-from .base import WeightAdapterBase, weight_decompose
+from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization
+
+
+class OFTDiff(WeightAdapterTrainBase):
+    def __init__(self, weights):
+        super().__init__()
+        # Unpack weights tuple from LoHaAdapter
+        blocks, rescale, alpha, _ = weights
+
+        # Create trainable parameters
+        self.oft_blocks = torch.nn.Parameter(blocks)
+        if rescale is not None:
+            self.rescale = torch.nn.Parameter(rescale)
+            self.rescaled = True
+        else:
+            self.rescaled = False
+        self.block_num, self.block_size, _ = blocks.shape
+        self.constraint = float(alpha)
+        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
+
+    def __call__(self, w):
+        org_dtype = w.dtype
+        I = torch.eye(self.block_size, device=self.oft_blocks.device)
+
+        ## generate r
+        # for Q = -Q^T
+        q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
+        normed_q = q
+        if self.constraint:
+            q_norm = torch.norm(q) + 1e-8
+            if q_norm > self.constraint:
+                normed_q = q * self.constraint / q_norm
+        # use float() to prevent unsupported type
+        r = (I + normed_q) @ (I - normed_q).float().inverse()
+
+        ## Apply chunked matmul on weight
+        _, *shape = w.shape
+        org_weight = w.to(dtype=r.dtype)
+        org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
+        # Init R=0, so add I on it to ensure the output of step0 is original model output
+        weight = torch.einsum(
+            "k n m, k n ... -> k m ...",
+            r,
+            org_weight,
+        ).flatten(0, 1)
+        if self.rescaled:
+            weight = self.rescale * weight
+        return weight.to(org_dtype)
+
+    def passive_memory_usage(self):
+        """Calculates memory usage of the trainable parameters."""
+        return sum(param.numel() * param.element_size() for param in self.parameters())
+
+
 class OFTAdapter(WeightAdapterBase):
@@ -13,6 +64,18 @@ class OFTAdapter(WeightAdapterBase):
         self.loaded_keys = loaded_keys
         self.weights = weights
 
+    @classmethod
+    def create_train(cls, weight, rank=1, alpha=1.0):
+        out_dim = weight.shape[0]
+        block_size, block_num = factorization(out_dim, rank)
+        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
+        return OFTDiff(
+            (block, None, alpha, None)
+        )
+
+    def to_train(self):
+        return OFTDiff(self.weights)
+
     @classmethod
     def load(
         cls,
@@ -60,6 +123,8 @@ class OFTAdapter(WeightAdapterBase):
         blocks = v[0]
         rescale = v[1]
         alpha = v[2]
+        if alpha is None:
+            alpha = 0
         dora_scale = v[3]
 
         blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
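OFTDiff trains a skew-symmetric block Q = B - B^T and turns it into an orthogonal rotation via the Cayley transform R = (I + Q)(I - Q)^-1; with the zero initialisation from create_train, R is the identity and training starts from the unmodified weights. A quick numerical check of both properties (toy block size, illustrative only):

    import torch

    torch.manual_seed(0)
    block = torch.randn(4, 4, dtype=torch.double)

    q = block - block.T                     # skew-symmetric: q.T == -q
    I = torch.eye(4, dtype=torch.double)
    r = (I + q) @ torch.linalg.inv(I - q)   # Cayley transform

    # The Cayley transform of a skew-symmetric matrix is orthogonal.
    assert torch.allclose(r @ r.T, I)

    # A zero block gives R == I, i.e. no rotation at initialisation.
    r0 = (I + torch.zeros_like(q)) @ torch.linalg.inv(I - torch.zeros_like(q))
    assert torch.allclose(r0, I)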
@@ -20,7 +20,7 @@ import folder_paths
 import node_helpers
 from comfy.cli_args import args
 from comfy.comfy_types.node_typing import IO
-from comfy.weight_adapter import adapters
+from comfy.weight_adapter import adapters, adapter_maps
 
 
 def make_batch_extra_option_dict(d, indicies, full_size=None):
@@ -39,13 +39,13 @@ def make_batch_extra_option_dict(d, indicies, full_size=None):
 
 
 class TrainSampler(comfy.samplers.Sampler):
-    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
         self.batch_size = batch_size
         self.total_steps = total_steps
+        self.grad_acc = grad_acc
         self.seed = seed
         self.training_dtype = training_dtype
 
@@ -92,6 +92,7 @@ class TrainSampler(comfy.samplers.Sampler):
                 self.loss_callback(loss.item())
             pbar.set_postfix({"loss": f"{loss.item():.4f}"})
 
+            if (i+1) % self.grad_acc == 0:
                 self.optimizer.step()
                 self.optimizer.zero_grad()
         torch.cuda.empty_cache()
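The accumulation logic is the standard PyTorch pattern: every micro-step calls loss.backward() so gradients sum into .grad, and optimizer.step()/zero_grad() only run every grad_acc iterations. A standalone sketch of the same pattern with a toy model (not the sampler's actual loop); note that, as in the hunk above, the loss is not rescaled by grad_acc, so accumulated gradients are summed rather than averaged:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.SmoothL1Loss()
    grad_acc = 4
    total_steps = 12  # 12 micro-steps -> 3 optimizer updates

    for i in range(total_steps):
        x, y = torch.randn(2, 4), torch.randn(2, 1)
        loss = loss_fn(model(x), y)
        loss.backward()               # gradients accumulate in .grad
        if (i + 1) % grad_acc == 0:
            optimizer.step()
            optimizer.zero_grad()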
@@ -419,6 +420,16 @@ class TrainLoraNode:
                         "tooltip": "The batch size to use for training.",
                     },
                 ),
+                "grad_accumulation_steps": (
+                    IO.INT,
+                    {
+                        "default": 1,
+                        "min": 1,
+                        "max": 1024,
+                        "step": 1,
+                        "tooltip": "The number of gradient accumulation steps to use for training.",
+                    }
+                ),
                 "steps": (
                     IO.INT,
                     {
@@ -478,6 +489,17 @@ class TrainLoraNode:
                     ["bf16", "fp32"],
                     {"default": "bf16", "tooltip": "The dtype to use for lora."},
                 ),
+                "algorithm": (
+                    list(adapter_maps.keys()),
+                    {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."},
+                ),
+                "gradient_checkpointing": (
+                    IO.BOOLEAN,
+                    {
+                        "default": True,
+                        "tooltip": "Use gradient checkpointing for training.",
+                    }
+                ),
                 "existing_lora": (
                     folder_paths.get_filename_list("loras") + ["[None]"],
                     {
@@ -501,6 +523,7 @@ class TrainLoraNode:
         positive,
         batch_size,
         steps,
+        grad_accumulation_steps,
         learning_rate,
         rank,
         optimizer,
@@ -508,6 +531,8 @@ class TrainLoraNode:
         seed,
         training_dtype,
         lora_dtype,
+        algorithm,
+        gradient_checkpointing,
         existing_lora,
     ):
         mp = model.clone()
@@ -558,10 +583,8 @@ class TrainLoraNode:
                 if existing_adapter is not None:
                     break
             else:
-                # If no existing adapter found, use LoRA
-                # We will add algo option in the future
                 existing_adapter = None
-                adapter_cls = adapters[0]
+                adapter_cls = adapter_maps[algorithm]
 
             if existing_adapter is not None:
                 train_adapter = existing_adapter.to_train().to(lora_dtype)
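With adapter_maps wired in, the node either keeps training an adapter recovered from an existing LoRA (to_train()) or builds a fresh one for the selected algorithm (create_train()). A condensed sketch of that branch; the variable values and weight shape are illustrative and assume a working ComfyUI environment:

    import torch
    from comfy.weight_adapter import adapter_maps

    algorithm = "OFT"
    lora_dtype = torch.float32
    weight = torch.randn(64, 64)   # stand-in for one model weight
    existing_adapter = None        # i.e. nothing matching was found in an existing LoRA

    if existing_adapter is not None:
        train_adapter = existing_adapter.to_train().to(lora_dtype)
    else:
        adapter_cls = adapter_maps[algorithm]
        train_adapter = adapter_cls.create_train(weight, rank=4, alpha=1.0).to(lora_dtype)

    # Calling the train adapter returns the adapted weight; with the zero
    # initialisation used by create_train it should equal the input at step 0.
    assert torch.allclose(train_adapter(weight), weight)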
@@ -615,6 +638,7 @@ class TrainLoraNode:
         criterion = torch.nn.SmoothL1Loss()
 
         # setup models
+        if gradient_checkpointing:
             for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model):
                 patch(m)
         mp.model.requires_grad_(False)
@@ -629,7 +653,8 @@ class TrainLoraNode:
             optimizer,
             loss_callback=loss_callback,
             batch_size=batch_size,
-            total_steps=steps,
+            grad_acc=grad_accumulation_steps,
+            total_steps=steps*grad_accumulation_steps,
             seed=seed,
             training_dtype=dtype
         )
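Because TrainSampler now counts micro-steps, the node passes total_steps=steps*grad_accumulation_steps, so the user-facing steps value still corresponds to optimizer updates while each update accumulates gradients over batch_size*grad_accumulation_steps samples. Illustrative arithmetic:

    steps = 100                  # optimizer updates requested in the node
    grad_accumulation_steps = 4
    batch_size = 2

    total_micro_steps = steps * grad_accumulation_steps     # 400 forward/backward passes
    updates = total_micro_steps // grad_accumulation_steps  # 100 optimizer.step() calls
    effective_batch = batch_size * grad_accumulation_steps  # gradients from 8 samples per update

    assert updates == steps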