Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2025-06-09 15:47:14 +00:00
Print errors and continue when lora weights are not compatible.
parent 4760c29380
commit 0115018695
comfy/sd.py (16 changes)
@@ -376,7 +376,10 @@ class ModelPatcher:
                         mat3 = v[3].float().to(weight.device)
                         final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
                         mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                    try:
+                        weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
+                    except Exception as e:
+                        print("ERROR", key, e)
                 elif len(v) == 8: #lokr
                     w1 = v[0]
                     w2 = v[1]
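For reference, the update guarded in this hunk is the lora/locon composition: the delta applied to the weight is alpha * (mat1 @ mat2), with mat2 optionally folded together with the mid weight mat3 first, then reshaped onto the target tensor. Below is a minimal standalone sketch of the guarded pattern; all shapes and the key name are hypothetical, not values from the repository.

import torch

# Hypothetical shapes and key name, illustrating the guarded lora/locon update:
# the delta is alpha * (mat1 @ mat2), reshaped onto the target weight.
weight = torch.zeros(320, 768)
mat1 = torch.randn(320, 4)   # down/up factors; rank 4 chosen arbitrarily
mat2 = torch.randn(4, 1024)  # deliberately sized for a 1024-wide layer -> incompatible
alpha = 1.0
key = "example.lora.key"     # hypothetical key name

try:
    weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
    print("ERROR", key, e)   # the reshape fails, the key is reported, patching continues

Printing and continuing means a lora that only partially matches the loaded checkpoint still applies the keys it can, instead of aborting the whole patch step.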
@@ -407,7 +410,10 @@ class ModelPatcher:
                     if v[2] is not None and dim is not None:
                         alpha *= v[2] / dim
 
-                    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                    try:
+                        weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
+                    except Exception as e:
+                        print("ERROR", key, e)
                 else: #loha
                     w1a = v[0]
                     w1b = v[1]
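Similarly, the lokr branch builds its delta as the Kronecker product of two small factors, so torch.kron(w1, w2) has to expand to exactly the weight's shape. A minimal sketch, again with hypothetical shapes and key name:

import torch

# Hypothetical shapes and key name, illustrating the guarded lokr update.
weight = torch.zeros(320, 768)
w1 = torch.randn(8, 8)
w2 = torch.randn(40, 96)     # kron(8x8, 40x96) -> 320x768, matching the weight
alpha = 1.0
key = "example.lokr.key"     # hypothetical key name

try:
    weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
    print("ERROR", key, e)   # only triggers when the factors don't multiply out to weight.shape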
@@ -424,7 +430,11 @@ class ModelPatcher:
                     m1 = torch.mm(w1a.float().to(weight.device), w1b.float().to(weight.device))
                     m2 = torch.mm(w2a.float().to(weight.device), w2b.float().to(weight.device))
 
-                    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                    try:
+                        weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
+                    except Exception as e:
+                        print("ERROR", key, e)
+
         return weight
 
     def unpatch_model(self):
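Finally, the loha branch in the last hunk combines two low-rank products elementwise (a Hadamard product) before applying the delta. A minimal sketch of that guarded update, with hypothetical shapes and key name:

import torch

# Hypothetical shapes and key name, illustrating the guarded loha update:
# two rank-4 products are combined elementwise, then reshaped onto the weight.
weight = torch.zeros(320, 768)
w1a, w1b = torch.randn(320, 4), torch.randn(4, 768)
w2a, w2b = torch.randn(320, 4), torch.randn(4, 768)
alpha = 1.0
key = "example.loha.key"     # hypothetical key name

m1 = torch.mm(w1a.float(), w1b.float())
m2 = torch.mm(w2a.float(), w2b.float())
try:
    weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
    print("ERROR", key, e)   # shape mismatches are reported per key instead of raising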