Speed up lora loading a bit.

This commit is contained in:
comfyanonymous
2023-07-15 13:24:05 -04:00
parent 50b1180dde
commit 490771b7f4
3 changed files with 35 additions and 25 deletions

View File

@@ -4,18 +4,20 @@ import struct
import comfy.checkpoint_pickle
import safetensors.torch
def load_torch_file(ckpt, safe_load=False):
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors"):
sd = safetensors.torch.load_file(ckpt, device="cpu")
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
if not 'weights_only' in torch.load.__code__.co_varnames:
print("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
safe_load = False
if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location="cpu", pickle_module=comfy.checkpoint_pickle)
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd: