Fix the bugs in OFT/BOFT moule (#7909)

* Correct calculate_weight and load for OFT

* Correct calculate_weight and loading for BOFT
This commit is contained in:
Kohaku-Blueleaf
2025-05-03 01:12:37 +08:00
committed by GitHub
parent d9a87c1e6a
commit 2ab9618732
2 changed files with 28 additions and 26 deletions

View File

@@ -32,17 +32,18 @@ class OFTAdapter(WeightAdapterBase):
blocks = lora[blocks_name]
if blocks.ndim == 3:
loaded_keys.add(blocks_name)
else:
blocks = None
if blocks is None:
return None
rescale = None
if rescale_name in lora.keys():
rescale = lora[rescale_name]
loaded_keys.add(rescale_name)
if blocks is not None:
weights = (blocks, rescale, alpha, dora_scale)
return cls(loaded_keys, weights)
else:
return None
weights = (blocks, rescale, alpha, dora_scale)
return cls(loaded_keys, weights)
def calculate_weight(
self,
@@ -79,16 +80,17 @@ class OFTAdapter(WeightAdapterBase):
normed_q = q * alpha / q_norm
# use float() to prevent unsupported type in .inverse()
r = (I + normed_q) @ (I - normed_q).float().inverse()
r = r.to(original_weight)
r = r.to(weight)
_, *shape = weight.shape
lora_diff = torch.einsum(
"k n m, k n ... -> k m ...",
(r * strength) - strength * I,
original_weight,
)
weight.view(block_num, block_size, *shape),
).view(-1, *shape)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
weight += function((strength * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
return weight