[Training Node] algo support, grad acc, optional grad ckpt (#9015)

* Add factorization utils for lokr

* Add lokr train impl

* Add loha train impl

* Add adapter map for algo selection

* Add optional grad ckpt and algo selection

* Update __init__.py

* correct key name for loha

* Use custom fwd/bwd func and better init for loha

* Support gradient accumulation

* Fix bugs of loha

* use more stable init

* Add OFT training

* linting
This commit is contained in:
Kohaku-Blueleaf
2025-07-24 08:57:27 +08:00
committed by GitHub
parent e729a5cc11
commit eb2f78b4e0
6 changed files with 372 additions and 15 deletions

View File

@@ -15,9 +15,20 @@ adapters: list[type[WeightAdapterBase]] = [
OFTAdapter,
BOFTAdapter,
]
# Lookup table from a user-facing algorithm name to the adapter class that
# implements it; used to select an adapter by name.
adapter_maps: dict[str, type[WeightAdapterBase]] = {
    "LoRA": LoRAAdapter,
    "LoHa": LoHaAdapter,
    "LoKr": LoKrAdapter,
    "OFT": OFTAdapter,
    # Algorithms that are not implemented yet stay disabled for now:
    # "GLoRA": GLoRAAdapter,
    # "BOFT": BOFTAdapter,
}
# Names exported by `from <package> import *`: the two base classes, the
# `adapters` list, the `adapter_maps` dict, and the class name of every
# entry in `adapters`.
# Every literal must end with a comma: two adjacent string literals are
# silently concatenated by Python (e.g. "adapters" "adapters" would become
# the single bogus name "adaptersadapters").
__all__ = [
    "WeightAdapterBase",
    "WeightAdapterTrainBase",
    "adapters",
    "adapter_maps",
] + [a.__name__ for a in adapters]