Try to fix a memory issue with LoRA.

comfyanonymous
2023-07-22 21:26:45 -04:00
parent 67be7eb81d
commit 22f29d66ca
2 changed files with 12 additions and 5 deletions

@@ -338,7 +338,7 @@ class ModelPatcher:
                     sd.pop(k)
         return sd

-    def patch_model(self):
+    def patch_model(self, device_to=None):
         model_sd = self.model_state_dict()
         for key in self.patches:
             if key not in model_sd:
@@ -350,10 +350,13 @@ class ModelPatcher:
             if key not in self.backup:
                 self.backup[key] = weight.to(self.offload_device)

-            temp_weight = weight.to(torch.float32, copy=True)
+            if device_to is not None:
+                temp_weight = weight.float().to(device_to, copy=True)
+            else:
+                temp_weight = weight.to(torch.float32, copy=True)
             out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
             set_attr(self.model, key, out_weight)
             del weight
             del temp_weight
         return self.model
     def calculate_weight(self, patches, weight, key):
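
In other words, patch_model now takes an optional device_to argument: when it is set, the float32 working copy of each weight is created on that device rather than next to the original weight, which the commit message suggests is meant to relieve memory pressure during LoRA patching. Below is a minimal, self-contained sketch of that code path; the helper name make_temp_weight and the toy tensors are illustrative and not part of the repository.

import torch

def make_temp_weight(weight, device_to=None):
    # Mirrors the branch added in this commit (illustrative helper, not a
    # function in the repository): build the float32 working copy on
    # device_to when one is given, otherwise on the weight's own device.
    if device_to is not None:
        return weight.float().to(device_to, copy=True)
    return weight.to(torch.float32, copy=True)

# Toy usage: an fp16 weight sitting on the offload device (CPU here),
# with the patch arithmetic done on the GPU when one is available.
weight = torch.randn(4, 4, dtype=torch.float16)
device_to = "cuda" if torch.cuda.is_available() else None
temp = make_temp_weight(weight, device_to)
out = (temp * 1.0).to(weight.dtype)  # stand-in for calculate_weight()
del temp

The apparent upshot is that the fp32 temporary, and the patch math on it, can live on a different device than the stored weights, so the weights' home device does not have to hold both the original tensor and its fp32 copy at the same time.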