mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-11 12:06:23 +00:00
[Training Node] algo support, grad acc, optional grad ckpt (#9015)
* Add factorization utils for lokr * Add lokr train impl * Add loha train impl * Add adapter map for algo selection * Add optional grad ckpt and algo selection * Update __init__.py * correct key name for loha * Use custom fwd/bwd func and better init for loha * Support gradient accumulation * Fix bugs of loha * use more stable init * Add OFT training * linting
This commit is contained in:
@@ -3,7 +3,120 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import comfy.model_management
|
||||
from .base import WeightAdapterBase, weight_decompose
|
||||
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose
|
||||
|
||||
|
||||
class HadaWeight(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)):
|
||||
ctx.save_for_backward(w1d, w1u, w2d, w2u, scale)
|
||||
diff_weight = ((w1u @ w1d) * (w2u @ w2d)) * scale
|
||||
return diff_weight
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(w1d, w1u, w2d, w2u, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
temp = grad_out * (w2u @ w2d)
|
||||
grad_w1u = temp @ w1d.T
|
||||
grad_w1d = w1u.T @ temp
|
||||
|
||||
temp = grad_out * (w1u @ w1d)
|
||||
grad_w2u = temp @ w2d.T
|
||||
grad_w2d = w2u.T @ temp
|
||||
|
||||
del temp
|
||||
return grad_w1u, grad_w1d, grad_w2u, grad_w2d, None
|
||||
|
||||
|
||||
class HadaWeightTucker(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, t1, w1u, w1d, t2, w2u, w2d, scale=torch.tensor(1)):
|
||||
ctx.save_for_backward(t1, w1d, w1u, t2, w2d, w2u, scale)
|
||||
|
||||
rebuild1 = torch.einsum("i j ..., j r, i p -> p r ...", t1, w1d, w1u)
|
||||
rebuild2 = torch.einsum("i j ..., j r, i p -> p r ...", t2, w2d, w2u)
|
||||
|
||||
return rebuild1 * rebuild2 * scale
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
(t1, w1d, w1u, t2, w2d, w2u, scale) = ctx.saved_tensors
|
||||
grad_out = grad_out * scale
|
||||
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t2, w2d)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w2u)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w1u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w1u.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w1d = torch.einsum("i r ..., i j ... -> r j", t1, grad_temp)
|
||||
grad_t1 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w1d.T)
|
||||
del grad_temp
|
||||
|
||||
temp = torch.einsum("i j ..., j r -> i r ...", t1, w1d)
|
||||
rebuild = torch.einsum("i j ..., i r -> r j ...", temp, w1u)
|
||||
|
||||
grad_w = rebuild * grad_out
|
||||
del rebuild
|
||||
|
||||
grad_w2u = torch.einsum("r j ..., i j ... -> r i", temp, grad_w)
|
||||
grad_temp = torch.einsum("i j ..., i r -> r j ...", grad_w, w2u.T)
|
||||
del grad_w, temp
|
||||
|
||||
grad_w2d = torch.einsum("i r ..., i j ... -> r j", t2, grad_temp)
|
||||
grad_t2 = torch.einsum("i j ..., j r -> i r ...", grad_temp, w2d.T)
|
||||
del grad_temp
|
||||
return grad_t1, grad_w1u, grad_w1d, grad_t2, grad_w2u, grad_w2d, None
|
||||
|
||||
|
||||
class LohaDiff(WeightAdapterTrainBase):
|
||||
def __init__(self, weights):
|
||||
super().__init__()
|
||||
# Unpack weights tuple from LoHaAdapter
|
||||
w1a, w1b, alpha, w2a, w2b, t1, t2, _ = weights
|
||||
|
||||
# Create trainable parameters
|
||||
self.hada_w1_a = torch.nn.Parameter(w1a)
|
||||
self.hada_w1_b = torch.nn.Parameter(w1b)
|
||||
self.hada_w2_a = torch.nn.Parameter(w2a)
|
||||
self.hada_w2_b = torch.nn.Parameter(w2b)
|
||||
|
||||
self.use_tucker = False
|
||||
if t1 is not None and t2 is not None:
|
||||
self.use_tucker = True
|
||||
self.hada_t1 = torch.nn.Parameter(t1)
|
||||
self.hada_t2 = torch.nn.Parameter(t2)
|
||||
else:
|
||||
# Keep the attributes for consistent access
|
||||
self.hada_t1 = None
|
||||
self.hada_t2 = None
|
||||
|
||||
# Store rank and non-trainable alpha
|
||||
self.rank = w1b.shape[0]
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)
|
||||
|
||||
def __call__(self, w):
|
||||
org_dtype = w.dtype
|
||||
|
||||
scale = self.alpha / self.rank
|
||||
if self.use_tucker:
|
||||
diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
else:
|
||||
diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale)
|
||||
|
||||
# Add the scaled difference to the original weight
|
||||
weight = w.to(diff_weight) + diff_weight.reshape(w.shape)
|
||||
|
||||
return weight.to(org_dtype)
|
||||
|
||||
def passive_memory_usage(self):
|
||||
"""Calculates memory usage of the trainable parameters."""
|
||||
return sum(param.numel() * param.element_size() for param in self.parameters())
|
||||
|
||||
|
||||
class LoHaAdapter(WeightAdapterBase):
|
||||
@@ -13,6 +126,25 @@ class LoHaAdapter(WeightAdapterBase):
|
||||
self.loaded_keys = loaded_keys
|
||||
self.weights = weights
|
||||
|
||||
@classmethod
|
||||
def create_train(cls, weight, rank=1, alpha=1.0):
|
||||
out_dim = weight.shape[0]
|
||||
in_dim = weight.shape[1:].numel()
|
||||
mat1 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.normal_(mat1, 0.1)
|
||||
torch.nn.init.constant_(mat2, 0.0)
|
||||
mat3 = torch.empty(out_dim, rank, device=weight.device, dtype=weight.dtype)
|
||||
mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=weight.dtype)
|
||||
torch.nn.init.normal_(mat3, 0.1)
|
||||
torch.nn.init.normal_(mat4, 0.01)
|
||||
return LohaDiff(
|
||||
(mat1, mat2, alpha, mat3, mat4, None, None, None)
|
||||
)
|
||||
|
||||
def to_train(self):
|
||||
return LohaDiff(self.weights)
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
|
Reference in New Issue
Block a user