diff --git a/comfy/model_management.py b/comfy/model_management.py index 21f7c7186..92c8ac842 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -233,10 +233,9 @@ def unload_model(): accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model) model_accelerated = False - + current_loaded_model.unpatch_model() current_loaded_model.model.to(current_loaded_model.offload_device) current_loaded_model.model_patches_to(current_loaded_model.offload_device) - current_loaded_model.unpatch_model() current_loaded_model = None if vram_state != VRAMState.HIGH_VRAM: soft_empty_cache() @@ -282,14 +281,6 @@ def load_model_gpu(model): elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED: model_accelerated = False real_model.to(torch_dev) - else: - if vram_set_state == VRAMState.NO_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) - elif vram_set_state == VRAMState.LOW_VRAM: - device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) - - accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) - model_accelerated = True try: real_model = model.patch_model() @@ -298,6 +289,15 @@ def load_model_gpu(model): unload_model() raise e + if vram_set_state == VRAMState.NO_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"}) + accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) + model_accelerated = True + elif vram_set_state == VRAMState.LOW_VRAM: + device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(lowvram_model_memory // (1024 * 1024)), "cpu": "16GiB"}) + accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev) + model_accelerated = True + return current_loaded_model def load_controlnet_gpu(control_models): diff --git a/comfy/sd.py b/comfy/sd.py index e30fae16c..9a96dbe8c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -428,11 +428,17 @@ class ModelPatcher: return weight def unpatch_model(self): - model_sd = self.model_state_dict() keys = list(self.backup.keys()) + def set_attr(obj, attr, value): + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + setattr(obj, attrs[-1], torch.nn.Parameter(value)) + del prev + for k in keys: - model_sd[k][:] = self.backup[k] - del self.backup[k] + set_attr(self.model, k, self.backup[k]) self.backup = {}