Speed up LoRA loading a bit.

Author: comfyanonymous
Date:   2023-07-15 13:24:05 -04:00
Parent: 50b1180dde
Commit: 490771b7f4

3 changed files with 35 additions and 25 deletions


@@ -258,15 +258,11 @@ def load_model_gpu(model):
     if model is current_loaded_model:
         return
     unload_model()
 
-    try:
-        real_model = model.patch_model()
-    except Exception as e:
-        model.unpatch_model()
-        raise e
-
     torch_dev = model.load_device
     model.model_patches_to(torch_dev)
     model.model_patches_to(model.model_dtype())
+
+    current_loaded_model = model
     if is_device_cpu(torch_dev):
         vram_set_state = VRAMState.DISABLED
@@ -280,8 +276,7 @@ def load_model_gpu(model):
         if model_size > (current_free_mem - minimum_inference_memory()): #only switch to lowvram if really necessary
             vram_set_state = VRAMState.LOW_VRAM
 
-    current_loaded_model = model
     real_model = model.model
     if vram_set_state == VRAMState.DISABLED:
         pass
     elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
@@ -295,6 +290,14 @@ def load_model_gpu(model):
         accelerate.dispatch_model(real_model, device_map=device_map, main_device=torch_dev)
         model_accelerated = True
 
+    try:
+        real_model = model.patch_model()
+    except Exception as e:
+        model.unpatch_model()
+        unload_model()
+        raise e
+
     return current_loaded_model
 
 def load_controlnet_gpu(control_models):
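
The reordering moves the try/patch_model() block from the top of load_model_gpu() to after the model has been placed on its load device, and the handler now also calls unload_model() so a failed patch rolls back the earlier current_loaded_model assignment. The plausible source of the speedup: patch_model() applies LoRA deltas with dense tensor math (roughly W' = W + alpha * up @ down), which is much faster once the weights already sit on the GPU. Below is a minimal, self-contained sketch of that effect; apply_lora, timed, and the 4096x4096 / rank-64 shapes are illustrative assumptions, not ComfyUI's actual ModelPatcher code.

    import time

    import torch


    def apply_lora(weight, up, down, alpha=1.0):
        # Core of a LoRA patch: W' = W + alpha * (up @ down).
        # The matmul and add run on whichever device the tensors live on.
        return weight + alpha * (up @ down)


    def timed(fn):
        # Wall-clock a callable, synchronizing CUDA so GPU work is counted.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.time()
        out = fn()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return out, time.time() - start


    # Hypothetical shapes: one 4096x4096 projection weight, rank-64 LoRA.
    W = torch.randn(4096, 4096)
    up, down = torch.randn(4096, 64), torch.randn(64, 4096)
    dev = "cuda" if torch.cuda.is_available() else "cpu"

    # Old ordering: patch on the CPU, then move the patched weight over.
    _, t_old = timed(lambda: apply_lora(W, up, down).to(dev))
    print(f"patch, then move: {t_old:.4f}s")

    # New ordering: move the tensors first, then patch on the device.
    _, t_new = timed(lambda: apply_lora(W.to(dev), up.to(dev), down.to(dev)))
    print(f"move, then patch: {t_new:.4f}s")

On a CUDA machine this should report a noticeably lower time for the move-then-patch ordering, since the rank-decomposed matmul and the add run on the GPU while the data moved over PCIe is roughly the same either way; on a CPU-only machine the two paths are identical, which matches the VRAMState.DISABLED handling above.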