import logging
from typing import Optional

import torch

import comfy.model_management
from .base import (
    WeightAdapterBase,
    WeightAdapterTrainBase,
    weight_decompose,
    factorization,
)


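# LoKr ("Kronecker product" adaptation from the LyCORIS family) stores a weight
# delta as a Kronecker product of two small factors, each of which may itself be
# given as a low-rank pair (and, for conv kernels, an additional Tucker core):
#
#     delta_W = scale * kron(w1, w2),   with w1 = w1_a @ w1_b,  w2 = w2_a @ w2_b
#
# where `scale` is derived from `alpha` and the rank of the factored side.
# LokrDiff below is the trainable form; LoKrAdapter loads and merges pretrained
# LoKr weights.
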
class LokrDiff(WeightAdapterTrainBase):
    """Trainable LoKr weight-diff module (returned by LoKrAdapter.create_train)."""

    def __init__(self, weights):
        super().__init__()
        (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights
        self.use_tucker = False
        if lokr_w1_a is not None:
            # w1 is given as a low-rank pair; the rank is the shared inner dimension.
            rank_a = lokr_w1_b.shape[0]
            self.lokr_w1_a = torch.nn.Parameter(lokr_w1_a)
            self.lokr_w1_b = torch.nn.Parameter(lokr_w1_b)
            self.w1_rebuild = True
            self.ranka = rank_a

        if lokr_w2_a is not None:
            # w2 is given as a low-rank pair, optionally with a Tucker core for conv kernels.
            rank_b = lokr_w2_b.shape[0]
            self.lokr_w2_a = torch.nn.Parameter(lokr_w2_a)
            self.lokr_w2_b = torch.nn.Parameter(lokr_w2_b)
            if lokr_t2 is not None:
                self.use_tucker = True
                self.lokr_t2 = torch.nn.Parameter(lokr_t2)
            self.w2_rebuild = True
            self.rankb = rank_b

        if lokr_w1 is not None:
            # A full (non-factorized) w1 overrides the low-rank path.
            self.lokr_w1 = torch.nn.Parameter(lokr_w1)
            self.w1_rebuild = False

        if lokr_w2 is not None:
            self.lokr_w2 = torch.nn.Parameter(lokr_w2)
            self.w2_rebuild = False

        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

    @property
    def w1(self):
        if self.w1_rebuild:
            return (self.lokr_w1_a @ self.lokr_w1_b) * (self.alpha / self.ranka)
        else:
            return self.lokr_w1

    @property
    def w2(self):
        if self.w2_rebuild:
            if self.use_tucker:
                # Tucker reconstruction for conv kernels:
                # w2[p, r, k, l] = sum_{i, j} t2[i, j, k, l] * w2_b[j, r] * w2_a[i, p]
                w2 = torch.einsum(
                    'i j k l, j r, i p -> p r k l',
                    self.lokr_t2,
                    self.lokr_w2_b,
                    self.lokr_w2_a
                )
            else:
                w2 = self.lokr_w2_a @ self.lokr_w2_b
            return w2 * (self.alpha / self.rankb)
        else:
            return self.lokr_w2

    def __call__(self, w):
        # The LoKr delta is the Kronecker product of the two (rebuilt) factors.
        diff = torch.kron(self.w1, self.w2)
        return w + diff.reshape(w.shape).to(w)

    def passive_memory_usage(self):
        return sum(param.numel() * param.element_size() for param in self.parameters())


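# Shape note (illustrative values, not from this file): if w1 has shape (a, b) and
# w2 has shape (c, d), torch.kron(w1, w2) has shape (a * c, b * d); factorization()
# in create_train below picks those factors so the product reshapes back onto the
# original layer weight.
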
class LoKrAdapter(WeightAdapterBase):
    name = "lokr"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
        in_dim = weight.shape[1:].numel()
        # Factor the output/input dims into two integers each, so that
        # kron(mat1, mat2) matches the original (out_dim, in_dim) shape.
        out1, out2 = factorization(out_dim, rank)
        in1, in2 = factorization(in_dim, rank)
        mat1 = torch.empty(out1, in1, device=weight.device, dtype=weight.dtype)
        mat2 = torch.empty(out2, in2, device=weight.device, dtype=weight.dtype)
        torch.nn.init.kaiming_uniform_(mat2, a=5**0.5)
        torch.nn.init.constant_(mat1, 0.0)  # zero init keeps the initial weight delta at zero
        return LokrDiff(
            (mat1, mat2, alpha, None, None, None, None, None, None)
        )

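# Training-setup sketch (illustrative names, not part of this module):
#
#   weight = torch.randn(768, 320)                   # hypothetical layer weight
#   diff_module = LoKrAdapter.create_train(weight, rank=8, alpha=8.0)
#   patched = diff_module(weight)                    # equals `weight` at init, since mat1 is zero-filled
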
    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["LoKrAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
            return cls(loaded_keys, weights)
        else:
            return None

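# Loading sketch (hedged; names below are illustrative, not from this module).
# `load` looks for LyCORIS-style keys under the given prefix, e.g.
# "<prefix>.lokr_w1", "<prefix>.lokr_w1_a"/"<prefix>.lokr_w1_b",
# "<prefix>.lokr_w2", "<prefix>.lokr_w2_a"/"<prefix>.lokr_w2_b", "<prefix>.lokr_t2":
#
#   sd = comfy.utils.load_torch_file("my_lokr.safetensors")   # hypothetical checkpoint
#   adapter = LoKrAdapter.load("lora_unet_some_block", sd, alpha=1.0,
#                              dora_scale=None, loaded_keys=set())
#   # -> returns None if no LoKr tensors exist under that prefix
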
    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        w1 = v[0]
        w2 = v[1]
        w1_a = v[3]
        w1_b = v[4]
        w2_a = v[5]
        w2_b = v[6]
        t2 = v[7]
        dora_scale = v[8]
        dim = None

        if w1 is None:
            dim = w1_b.shape[0]
            w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
                          comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
        else:
            w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)

        if w2 is None:
            dim = w2_b.shape[0]
            if t2 is None:
                w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
                              comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
            else:
                w2 = torch.einsum('i j k l, j r, i p -> p r k l',
                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
                                  comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
        else:
            w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)

        if len(w2.shape) == 4:
            # Conv case: give w1 singleton spatial dims so the Kronecker product yields a 4D kernel.
            w1 = w1.unsqueeze(2).unsqueeze(2)
        if v[2] is not None and dim is not None:
            # LoRA-style scaling: alpha divided by the rank of the factored component.
            alpha = v[2] / dim
        else:
            alpha = 1.0

        try:
            lora_diff = torch.kron(w1, w2).reshape(weight.shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
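
# Merge sketch (illustrative names only; in ComfyUI this method is normally driven
# by the model patching code rather than called directly):
#
#   adapter = LoKrAdapter.load("some_prefix", sd, alpha=1.0, dora_scale=None)
#   if adapter is not None:
#       merged = adapter.calculate_weight(
#           weight.clone(), "some_prefix", strength=1.0, strength_model=1.0,
#           offset=None, function=lambda a: a,
#       )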