import logging
from typing import Optional

import torch

import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization

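# OFT (Orthogonal Fine-Tuning) adapts a layer by multiplying its weight by a learned
# block-diagonal orthogonal matrix instead of adding a low-rank update. The orthogonal
# blocks are built from trainable skew-symmetric blocks via the Cayley transform
# R = (I + Q) @ (I - Q)^-1 used in the classes below.
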
class OFTDiff(WeightAdapterTrainBase):
    def __init__(self, weights):
        super().__init__()
        # Unpack weights tuple from OFTAdapter
        blocks, rescale, alpha, _ = weights

        # Create trainable parameters
        self.oft_blocks = torch.nn.Parameter(blocks)
        if rescale is not None:
            self.rescale = torch.nn.Parameter(rescale)
            self.rescaled = True
        else:
            self.rescaled = False
        self.block_num, self.block_size, _ = blocks.shape
        self.constraint = float(alpha)
        self.alpha = torch.nn.Parameter(torch.tensor(alpha), requires_grad=False)

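    # Forward pass: build a skew-symmetric Q from the trainable blocks (Q = B - B^T),
    # rescale it if its norm exceeds `constraint`, then map it to an orthogonal
    # rotation R via the Cayley transform R = (I + Q)(I - Q)^-1 and apply R
    # block-wise to the incoming weight.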
    def __call__(self, w):
        org_dtype = w.dtype
        I = torch.eye(self.block_size, device=self.oft_blocks.device)

        ## generate r
        # for Q = -Q^T
        q = self.oft_blocks - self.oft_blocks.transpose(1, 2)
        normed_q = q
        if self.constraint:
            q_norm = torch.norm(q) + 1e-8
            if q_norm > self.constraint:
                normed_q = q * self.constraint / q_norm
        # use float() to prevent unsupported type
        r = (I + normed_q) @ (I - normed_q).float().inverse()

        ## Apply chunked matmul on weight
        _, *shape = w.shape
        org_weight = w.to(dtype=r.dtype)
        org_weight = org_weight.unflatten(0, (self.block_num, self.block_size))
        # blocks are zero-initialized, so Q = 0 and R = I: step 0 reproduces the original model output
        weight = torch.einsum(
            "k n m, k n ... -> k m ...",
            r,
            org_weight,
        ).flatten(0, 1)
        if self.rescaled:
            weight = self.rescale * weight
        return weight.to(org_dtype)

    def passive_memory_usage(self):
        """Calculates memory usage of the trainable parameters."""
        return sum(param.numel() * param.element_size() for param in self.parameters())


class OFTAdapter(WeightAdapterBase):
    name = "oft"

    def __init__(self, loaded_keys, weights):
        self.loaded_keys = loaded_keys
        self.weights = weights

    @classmethod
    def create_train(cls, weight, rank=1, alpha=1.0):
        out_dim = weight.shape[0]
        # factorization() splits out_dim into block_num blocks of size block_size
        # (block_num * block_size == out_dim); the blocks are zero-initialized so the
        # adapter starts out as the identity transform.
        block_size, block_num = factorization(out_dim, rank)
        block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=weight.dtype)
        return OFTDiff(
            (block, None, alpha, None)
        )

    def to_train(self):
        return OFTDiff(self.weights)

    @classmethod
    def load(
        cls,
        x: str,
        lora: dict[str, torch.Tensor],
        alpha: float,
        dora_scale: torch.Tensor,
        loaded_keys: set[str] = None,
    ) -> Optional["OFTAdapter"]:
        if loaded_keys is None:
            loaded_keys = set()
        blocks_name = "{}.oft_blocks".format(x)
        rescale_name = "{}.rescale".format(x)

        blocks = None
        if blocks_name in lora.keys():
            blocks = lora[blocks_name]
            if blocks.ndim == 3:
                loaded_keys.add(blocks_name)
            else:
                blocks = None
        if blocks is None:
            return None

        rescale = None
        if rescale_name in lora.keys():
            rescale = lora[rescale_name]
            loaded_keys.add(rescale_name)

        weights = (blocks, rescale, alpha, dora_scale)
        return cls(loaded_keys, weights)

    def calculate_weight(
        self,
        weight,
        key,
        strength,
        strength_model,
        offset,
        function,
        intermediate_dtype=torch.float32,
        original_weight=None,
    ):
        v = self.weights
        blocks = v[0]
        rescale = v[1]
        alpha = v[2]
        if alpha is None:
            alpha = 0
        dora_scale = v[3]

        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
        if rescale is not None:
            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)

        block_num, block_size, *_ = blocks.shape

        try:
            # Get r
            I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
            # for Q = -Q^T
            q = blocks - blocks.transpose(1, 2)
            normed_q = q
            if alpha > 0:  # alpha in oft/boft is for constraint
                q_norm = torch.norm(q) + 1e-8
                if q_norm > alpha:
                    normed_q = q * alpha / q_norm
            # use float() to prevent unsupported type in .inverse()
            r = (I + normed_q) @ (I - normed_q).float().inverse()
            r = r.to(weight)
            _, *shape = weight.shape
            lora_diff = torch.einsum(
                "k n m, k n ... -> k m ...",
                (r * strength) - strength * I,
                weight.view(block_num, block_size, *shape),
            ).view(-1, *shape)
            if dora_scale is not None:
                weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
            else:
                weight += function((strength * lora_diff).type(weight.dtype))
        except Exception as e:
            logging.error("ERROR {} {} {}".format(self.name, key, e))
        return weight
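
# A minimal usage sketch (illustrative only; `lora_sd`, `layer_key`, `w`, and the
# alpha value are hypothetical placeholders, and `function` is passed as identity):
#
#   adapter = OFTAdapter.load(layer_key, lora_sd, alpha=4.0, dora_scale=None)
#   if adapter is not None:
#       w = adapter.calculate_weight(
#           w, layer_key, strength=1.0, strength_model=1.0,
#           offset=None, function=lambda a: a,
#       )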