Don't unload model weights for non weight patches.

This commit is contained in:
comfyanonymous
2024-03-20 01:29:26 -04:00
parent 150a3e946f
commit c18a203a8a
2 changed files with 76 additions and 28 deletions

View File

@@ -2,6 +2,7 @@ import torch
import copy
import inspect
import logging
import uuid
import comfy.utils
import comfy.model_management
@@ -25,6 +26,7 @@ class ModelPatcher:
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.patches_uuid = uuid.uuid4()
def model_size(self):
if self.size > 0:
@@ -39,10 +41,13 @@ class ModelPatcher:
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.model_keys = self.model_keys
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
return n
def is_clone(self, other):
@@ -50,6 +55,19 @@ class ModelPatcher:
return True
return False
def clone_has_same_weights(self, clone):
if not self.is_clone(clone):
return False
if len(self.patches) == 0 and len(clone.patches) == 0:
return True
if self.patches_uuid == clone.patches_uuid:
if len(self.patches) != len(clone.patches):
logging.warning("WARNING: something went wrong, same patch uuid but different length of patches.")
else:
return True
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
@@ -154,6 +172,7 @@ class ModelPatcher:
current_patches.append((strength_patch, patches[k], strength_model))
self.patches[k] = current_patches
self.patches_uuid = uuid.uuid4()
return list(p)
def get_key_patches(self, filter_prefix=None):
@@ -387,31 +406,32 @@ class ModelPatcher:
return weight
def unpatch_model(self, device_to=None):
if self.model_lowvram:
for m in self.model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
if self.model_lowvram:
for m in self.model.modules():
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
self.model_lowvram = False
self.model_lowvram = False
keys = list(self.backup.keys())
keys = list(self.backup.keys())
if self.weight_inplace_update:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
comfy.utils.set_attr_param(self.model, k, self.backup[k])
if self.weight_inplace_update:
for k in keys:
comfy.utils.copy_to_param(self.model, k, self.backup[k])
else:
for k in keys:
comfy.utils.set_attr_param(self.model, k, self.backup[k])
self.backup = {}
self.backup.clear()
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
if device_to is not None:
self.model.to(device_to)
self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys: