mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-09-12 04:27:21 +00:00
Fix the bugs in OFT/BOFT moule (#7909)
* Correct calculate_weight and load for OFT * Correct calculate_weight and loading for BOFT
This commit is contained in:
@@ -32,17 +32,18 @@ class OFTAdapter(WeightAdapterBase):
|
||||
blocks = lora[blocks_name]
|
||||
if blocks.ndim == 3:
|
||||
loaded_keys.add(blocks_name)
|
||||
else:
|
||||
blocks = None
|
||||
if blocks is None:
|
||||
return None
|
||||
|
||||
rescale = None
|
||||
if rescale_name in lora.keys():
|
||||
rescale = lora[rescale_name]
|
||||
loaded_keys.add(rescale_name)
|
||||
|
||||
if blocks is not None:
|
||||
weights = (blocks, rescale, alpha, dora_scale)
|
||||
return cls(loaded_keys, weights)
|
||||
else:
|
||||
return None
|
||||
weights = (blocks, rescale, alpha, dora_scale)
|
||||
return cls(loaded_keys, weights)
|
||||
|
||||
def calculate_weight(
|
||||
self,
|
||||
@@ -79,16 +80,17 @@ class OFTAdapter(WeightAdapterBase):
|
||||
normed_q = q * alpha / q_norm
|
||||
# use float() to prevent unsupported type in .inverse()
|
||||
r = (I + normed_q) @ (I - normed_q).float().inverse()
|
||||
r = r.to(original_weight)
|
||||
r = r.to(weight)
|
||||
_, *shape = weight.shape
|
||||
lora_diff = torch.einsum(
|
||||
"k n m, k n ... -> k m ...",
|
||||
(r * strength) - strength * I,
|
||||
original_weight,
|
||||
)
|
||||
weight.view(block_num, block_size, *shape),
|
||||
).view(-1, *shape)
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
weight += function((strength * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(self.name, key, e))
|
||||
return weight
|
||||
|
Reference in New Issue
Block a user