Fix issue with lowvram mode breaking model saving.

This commit is contained in:
comfyanonymous
2024-05-11 21:46:05 -04:00
parent 4f63ee99f1
commit e1489ad257
4 changed files with 15 additions and 9 deletions

View File

@@ -285,7 +285,7 @@ class LoadedModel:
else:
return self.model_memory()
def model_load(self, lowvram_model_memory=0):
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
patch_model_to = self.device
self.model.model_patches_to(self.device)
@@ -295,7 +295,7 @@ class LoadedModel:
try:
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e:
@@ -379,7 +379,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
def load_models_gpu(models, memory_required=0):
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
global vram_state
inference_memory = minimum_inference_memory()
@@ -444,7 +444,7 @@ def load_models_gpu(models, memory_required=0):
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return