Implement Linear hypernetworks.

Add a HypernetworkLoader node to use hypernetworks.
This commit is contained in:
comfyanonymous
2023-04-23 12:35:25 -04:00
parent 6908f9c949
commit 5282f56434
9 changed files with 185 additions and 16 deletions

View File

@@ -1,11 +1,14 @@
import torch
def load_torch_file(ckpt):
def load_torch_file(ckpt, safe_load=False):
if ckpt.lower().endswith(".safetensors"):
import safetensors.torch
sd = safetensors.torch.load_file(ckpt, device="cpu")
else:
pl_sd = torch.load(ckpt, map_location="cpu")
if safe_load:
pl_sd = torch.load(ckpt, map_location="cpu", weights_only=True)
else:
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd: