Lower LoRA RAM usage when in normal VRAM mode.

comfyanonymous
2023-07-16 02:48:09 -04:00
parent 490771b7f4
commit 5f57362613
2 changed files with 19 additions and 13 deletions

@@ -428,11 +428,17 @@ class ModelPatcher:
         return weight

     def unpatch_model(self):
-        model_sd = self.model_state_dict()
         keys = list(self.backup.keys())

+        def set_attr(obj, attr, value):
+            attrs = attr.split(".")
+            for name in attrs[:-1]:
+                obj = getattr(obj, name)
+            prev = getattr(obj, attrs[-1])
+            setattr(obj, attrs[-1], torch.nn.Parameter(value))
+            del prev
+
         for k in keys:
-            model_sd[k][:] = self.backup[k]
-            del self.backup[k]
+            set_attr(self.model, k, self.backup[k])

         self.backup = {}
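
The hunk replaces the old restore path, which built a full state-dict mapping via self.model_state_dict() and copied every backup tensor into the live parameter in place (model_sd[k][:] = self.backup[k]), with direct attribute replacement: set_attr() walks the dotted key down to the owning submodule, swaps in a new torch.nn.Parameter wrapping the backup tensor, and drops the reference to the patched weight so it can be freed. Below is a minimal, self-contained sketch of that restore pattern. The set_attr helper is taken from the diff; the toy nn.Sequential model, the fake LoRA-style patch, and the final check are illustrative assumptions, not part of the commit.

import torch
import torch.nn as nn

def set_attr(obj, attr, value):
    # Walk the dotted key ("0.weight" -> model[0].weight) down to the
    # owning submodule, then replace the leaf parameter outright.
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    prev = getattr(obj, attrs[-1])
    setattr(obj, attrs[-1], torch.nn.Parameter(value))
    del prev  # drop the last local reference to the patched tensor

# Hypothetical toy model standing in for the patched diffusion model.
model = nn.Sequential(nn.Linear(4, 4))

# Back up the original weights, then simulate a LoRA-style in-place patch.
backup = {k: v.clone() for k, v in model.state_dict().items()}
with torch.no_grad():
    model[0].weight += 1.0  # stand-in for applying a LoRA delta

# Restore by swapping parameter objects instead of copying into a
# state dict: nn.Parameter(value) shares the backup tensor's storage,
# so no second full copy of the weights is created during the restore.
for k in list(backup.keys()):
    set_attr(model, k, backup[k])

assert torch.equal(model[0].weight, backup["0.weight"])

The design point is that replacing the parameter object avoids holding both the patched tensor and the backup at once during restore; the second changed file in the commit is not shown on this page.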