Fix the bugs in OFT/BOFT module (#7909)

* Correct calculate_weight and load for OFT
* Correct calculate_weight and loading for BOFT
parent d9a87c1e6a
commit 2ab9618732
comfy/weight_adapter/boft.py

@@ -24,7 +24,7 @@ class BOFTAdapter(WeightAdapterBase):
     ) -> Optional["BOFTAdapter"]:
         if loaded_keys is None:
             loaded_keys = set()
-        blocks_name = "{}.boft_blocks".format(x)
+        blocks_name = "{}.oft_blocks".format(x)
         rescale_name = "{}.rescale".format(x)
 
         blocks = None
@@ -32,17 +32,18 @@ class BOFTAdapter(WeightAdapterBase):
             blocks = lora[blocks_name]
             if blocks.ndim == 4:
                 loaded_keys.add(blocks_name)
+            else:
+                blocks = None
+        if blocks is None:
+            return None
 
         rescale = None
         if rescale_name in lora.keys():
             rescale = lora[rescale_name]
             loaded_keys.add(rescale_name)
 
-        if blocks is not None:
-            weights = (blocks, rescale, alpha, dora_scale)
-            return cls(loaded_keys, weights)
-        else:
-            return None
+        weights = (blocks, rescale, alpha, dora_scale)
+        return cls(loaded_keys, weights)
 
     def calculate_weight(
         self,
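Note on the key rename above: after this fix, both the OFT and BOFT loaders look up "<name>.oft_blocks" and tell the two formats apart by tensor rank (3-D for OFT, 4-D for BOFT), bailing out early on a mismatch. A minimal standalone sketch of that control flow, with hypothetical checkpoint keys (not ComfyUI's actual loader API):

import torch

# Hypothetical checkpoint tensors; both formats use "<name>.oft_blocks".
lora = {
    "unet_a.oft_blocks": torch.zeros(4, 8, 8),     # OFT: (block_num, b, b)
    "unet_b.oft_blocks": torch.zeros(3, 4, 8, 8),  # BOFT: (m, num_blocks, b, b)
}

def load_boft_blocks(name: str):
    """Mirrors the fixed control flow: reject early when the key is absent
    or the rank is wrong, instead of checking once at the end."""
    blocks = lora.get("{}.oft_blocks".format(name))
    if blocks is not None and blocks.ndim != 4:    # rank 4 marks BOFT
        blocks = None
    if blocks is None:
        return None
    return blocks

print(load_boft_blocks("unet_a"))        # None: 3-D blocks belong to plain OFT
print(load_boft_blocks("unet_b").shape)  # torch.Size([3, 4, 8, 8])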
@@ -71,7 +72,7 @@ class BOFTAdapter(WeightAdapterBase):
             # Get r
             I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
             # for Q = -Q^T
-            q = blocks - blocks.transpose(1, 2)
+            q = blocks - blocks.transpose(-1, -2)
             normed_q = q
             if alpha > 0:  # alpha in boft/bboft is for constraint
                 q_norm = torch.norm(q) + 1e-8
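Why transpose(-1, -2) matters here: BOFT blocks are 4-D, shaped roughly (m, num_blocks, b, b), so only the last two axes are matrix axes. The old transpose(1, 2) swapped num_blocks against a matrix axis and broke the skew-symmetric construction. A small sketch with made-up sizes:

import torch

blocks = torch.randn(3, 4, 8, 8)       # assumed (m, num_blocks, b, b)

q = blocks - blocks.transpose(-1, -2)  # fixed: per-block Q with Q = -Q^T
print(torch.allclose(q, -q.transpose(-1, -2)))  # True

try:
    blocks - blocks.transpose(1, 2)    # old code: swaps num_blocks against a matrix axis
except RuntimeError as e:
    print("shape mismatch:", e)        # raises whenever num_blocks != b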
@@ -79,9 +80,8 @@ class BOFTAdapter(WeightAdapterBase):
                 normed_q = q * alpha / q_norm
             # use float() to prevent unsupported type in .inverse()
             r = (I + normed_q) @ (I - normed_q).float().inverse()
-            r = r.to(original_weight)
-            inp = org = original_weight
-
+            r = r.to(weight)
+            inp = org = weight
 
             r_b = boft_b//2
             for i in range(boft_m):
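For reference, the r built above is a Cayley transform: for skew-symmetric Q, (I + Q)(I - Q)^-1 is orthogonal, which is what makes OFT/BOFT a norm-preserving rotation of the weight rows. (The hunk also retargets r.to(...) and inp at weight, the parameter actually in scope in calculate_weight; original_weight appears to be a leftover name.) A quick self-contained check:

import torch

b = 8
a = torch.randn(b, b)
q = a - a.T                                   # skew-symmetric: Q = -Q^T
I = torch.eye(b)
r = (I + q) @ (I - q).float().inverse()      # Cayley transform

print(torch.allclose(r @ r.T, I, atol=1e-4))  # True: R is orthogonal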
@@ -91,14 +91,14 @@ class BOFTAdapter(WeightAdapterBase):
                 if strength != 1:
                     bi = bi * strength + (1-strength) * I
                 inp = (
-                    inp.unflatten(-1, (-1, g, k))
-                    .transpose(-2, -1)
-                    .flatten(-3)
-                    .unflatten(-1, (-1, boft_b))
+                    inp.unflatten(0, (-1, g, k))
+                    .transpose(1, 2)
+                    .flatten(0, 2)
+                    .unflatten(0, (-1, boft_b))
                 )
-                inp = torch.einsum("b n m, b n ... -> b m ...", inp, bi)
+                inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
                 inp = (
-                    inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3)
+                    inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
                 )
 
                 if rescale is not None:
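The reshape fix above moves the butterfly permutation from the last axis to axis 0, the output-channel axis of the weight. A self-contained sketch with made-up sizes: with identity rotation blocks the whole stage must be a no-op, which confirms that the dim-0 grouping and its inverse match:

import torch

N, boft_b, i = 16, 4, 1
r_b = boft_b // 2                                  # 2
g, k = 2, (2 ** i) * r_b                           # g=2, k=4 for this stage

w = torch.randn(N, 7)                              # (out_features, in_features)
bi = torch.eye(boft_b).expand(N // boft_b, boft_b, boft_b)  # identity rotations

inp = (
    w.unflatten(0, (-1, g, k))                     # (2, 2, 4, 7): stage grouping
    .transpose(1, 2)                               # (2, 4, 2, 7)
    .flatten(0, 2)                                 # (16, 7), rows permuted
    .unflatten(0, (-1, boft_b))                    # (4, 4, 7): blocks of rows
)
inp = torch.einsum("b i j, b j ... -> b i ...", bi, inp)   # rotate each block
out = (
    inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
)

print(torch.allclose(out, w))                      # True: identity is a no-op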
@@ -109,7 +109,7 @@ class BOFTAdapter(WeightAdapterBase):
             if dora_scale is not None:
                 weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
             else:
-                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+                weight += function((strength * lora_diff).type(weight.dtype))
         except Exception as e:
             logging.error("ERROR {} {} {}".format(self.name, key, e))
         return weight
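On the dropped "* alpha" in the merge: in OFT/BOFT, alpha constrains the skew matrix Q (applied earlier as "q * alpha / q_norm"), so it is not a LoRA-style merge scale, and multiplying lora_diff by it again would apply the constraint twice. Toy numbers only:

import torch

weight = torch.ones(2, 2)
lora_diff = torch.full((2, 2), 0.5)   # stands in for (R - I) @ W, already alpha-constrained
strength = 0.8

merged = weight + (strength * lora_diff).type(weight.dtype)
print(merged)                          # tensor filled with 1.4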
comfy/weight_adapter/oft.py

@@ -32,17 +32,18 @@ class OFTAdapter(WeightAdapterBase):
             blocks = lora[blocks_name]
             if blocks.ndim == 3:
                 loaded_keys.add(blocks_name)
+            else:
+                blocks = None
+        if blocks is None:
+            return None
 
         rescale = None
         if rescale_name in lora.keys():
             rescale = lora[rescale_name]
             loaded_keys.add(rescale_name)
 
-        if blocks is not None:
-            weights = (blocks, rescale, alpha, dora_scale)
-            return cls(loaded_keys, weights)
-        else:
-            return None
+        weights = (blocks, rescale, alpha, dora_scale)
+        return cls(loaded_keys, weights)
 
     def calculate_weight(
         self,
@@ -79,16 +80,17 @@ class OFTAdapter(WeightAdapterBase):
                 normed_q = q * alpha / q_norm
             # use float() to prevent unsupported type in .inverse()
             r = (I + normed_q) @ (I - normed_q).float().inverse()
-            r = r.to(original_weight)
+            r = r.to(weight)
+            _, *shape = weight.shape
             lora_diff = torch.einsum(
                 "k n m, k n ... -> k m ...",
                 (r * strength) - strength * I,
-                original_weight,
-            )
+                weight.view(block_num, block_size, *shape),
+            ).view(-1, *shape)
             if dora_scale is not None:
                 weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
             else:
-                weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+                weight += function((strength * lora_diff).type(weight.dtype))
         except Exception as e:
             logging.error("ERROR {} {} {}".format(self.name, key, e))
         return weight
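The OFT fix above computes the delta against weight itself, viewed as block_num blocks of block_size rows. A standalone sketch with made-up sizes, mirroring the patched einsum (the names block_num/block_size come from the patch; everything else is illustrative): with orthogonal per-block rotations at strength = 1, every column of the merged weight keeps its norm:

import torch

block_num, block_size, in_features = 4, 8, 5
strength = 1.0

weight = torch.randn(block_num * block_size, in_features)
a = torch.randn(block_num, block_size, block_size)
q = a - a.transpose(-1, -2)                       # skew-symmetric per block
I = torch.eye(block_size)
r = (I + q) @ (I - q).float().inverse()           # batched Cayley transform

_, *shape = weight.shape                          # shape == [in_features]
lora_diff = torch.einsum(
    "k n m, k n ... -> k m ...",
    (r * strength) - strength * I,
    weight.view(block_num, block_size, *shape),
).view(-1, *shape)

merged = weight + lora_diff
# Orthogonal per-block rotations preserve the norm of every column slice.
print(torch.allclose(merged.norm(dim=0), weight.norm(dim=0), atol=1e-4))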