Make lora code a bit cleaner.

This commit is contained in:
comfyanonymous
2023-12-09 14:15:09 -05:00
parent 9e411073e9
commit cb63e230b4
2 changed files with 18 additions and 10 deletions

View File

@@ -43,7 +43,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = (lora[A_name], lora[B_name], alpha, mid)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@@ -64,7 +64,7 @@ def load_lora(lora, to_load):
loaded_keys.add(hada_t1_name)
loaded_keys.add(hada_t2_name)
patch_dict[to_load[x]] = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2)
patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
loaded_keys.add(hada_w1_a_name)
loaded_keys.add(hada_w1_b_name)
loaded_keys.add(hada_w2_a_name)
@@ -116,7 +116,7 @@ def load_lora(lora, to_load):
loaded_keys.add(lokr_t2_name)
if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
patch_dict[to_load[x]] = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2)
patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
w_norm_name = "{}.w_norm".format(x)
@@ -126,21 +126,21 @@ def load_lora(lora, to_load):
if w_norm is not None:
loaded_keys.add(w_norm_name)
patch_dict[to_load[x]] = (w_norm,)
patch_dict[to_load[x]] = ("diff", (w_norm,))
if b_norm is not None:
loaded_keys.add(b_norm_name)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (b_norm,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
diff_name = "{}.diff".format(x)
diff_weight = lora.get(diff_name, None)
if diff_weight is not None:
patch_dict[to_load[x]] = (diff_weight,)
patch_dict[to_load[x]] = ("diff", (diff_weight,))
loaded_keys.add(diff_name)
diff_bias_name = "{}.diff_b".format(x)
diff_bias = lora.get(diff_bias_name, None)
if diff_bias is not None:
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = (diff_bias,)
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
for x in lora.keys():