Replace prints with logging and add --verbose argument.

This commit is contained in:
comfyanonymous
2024-03-10 11:37:08 -04:00
parent 4656273e72
commit 65397ce601
12 changed files with 90 additions and 65 deletions

View File

@@ -1,6 +1,7 @@
import torch
import copy
import inspect
import logging
import comfy.utils
import comfy.model_management
@@ -187,7 +188,7 @@ class ModelPatcher:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
print("could not patch. key doesn't exist in model:", key)
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
weight = model_sd[key]
@@ -236,7 +237,7 @@ class ModelPatcher:
w1 = v[0]
if alpha != 0.0:
if w1.shape != weight.shape:
print("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += alpha * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype)
elif patch_type == "lora": #lora/locon
@@ -252,7 +253,7 @@ class ModelPatcher:
try:
weight += (alpha * torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1))).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
@@ -291,7 +292,7 @@ class ModelPatcher:
try:
weight += alpha * torch.kron(w1, w2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
@@ -320,7 +321,7 @@ class ModelPatcher:
try:
weight += (alpha * m1 * m2).reshape(weight.shape).type(weight.dtype)
except Exception as e:
print("ERROR", key, e)
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha *= v[4] / v[0].shape[0]
@@ -330,9 +331,12 @@ class ModelPatcher:
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
try:
weight += ((torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)) * alpha).reshape(weight.shape).type(weight.dtype)
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
print("patch type not recognized", patch_type, key)
logging.warning("patch type not recognized {} {}".format(patch_type, key))
return weight