Fix lowvram model merging.

comfyanonymous
2023-08-26 11:52:07 -04:00
parent f72780a7e3
commit a57b0c797b
3 changed files with 15 additions and 7 deletions

comfy/model_management.py

@@ -1,6 +1,7 @@
 import psutil
 from enum import Enum
 from comfy.cli_args import args
+import comfy.utils
 import torch
 import sys
@@ -637,6 +638,13 @@ def soft_empty_cache():
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
 
+def resolve_lowvram_weight(weight, model, key):
+    if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
+        key_split = key.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
+        op = comfy.utils.get_attr(model, '.'.join(key_split[:-1]))
+        weight = op._hf_hook.weights_map[key_split[-1]]
+    return weight
+
 #TODO: might be cleaner to put this somewhere else
 import threading
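
For context, here is a minimal, self-contained sketch of what the new helper does. This is not ComfyUI or accelerate code: `get_attr` is a stand-in for `comfy.utils.get_attr` (assumed to walk a dotted attribute path), and `_demo_weights_map` is a hypothetical plain dict standing in for accelerate's `op._hf_hook.weights_map`, where the real tensors live while the module itself only holds a storage-less meta tensor.

```python
import torch
import torch.nn as nn

def get_attr(obj, attr):
    # Stand-in for comfy.utils.get_attr: walk a dotted path like "0.weight"
    # one attribute at a time; nn.Module.__getattr__ resolves numeric names
    # to child modules, so this works on keys like "input_blocks.0.weight".
    for name in attr.split('.'):
        obj = getattr(obj, name)
    return obj

# Toy model: simulate accelerate's offloading by parking the parameter on
# the storage-less "meta" device and keeping the real tensor in a side
# table (accelerate keeps that table in op._hf_hook.weights_map).
model = nn.Sequential(nn.Linear(4, 4))
real_weight = model[0].weight.detach().clone()
model[0]._demo_weights_map = {"weight": real_weight}  # hypothetical stand-in
model[0].weight = nn.Parameter(torch.empty(4, 4, device="meta"))

def resolve_lowvram_weight(weight, model, key):
    # Mirrors the helper added in this commit, with the hook access
    # swapped for the demo dict above.
    if weight.device == torch.device("meta"):  # offloaded: data lives elsewhere
        key_split = key.split('.')
        op = get_attr(model, '.'.join(key_split[:-1]))
        weight = op._demo_weights_map[key_split[-1]]  # real code: op._hf_hook.weights_map[...]
    return weight

key = "0.weight"
w = resolve_lowvram_weight(get_attr(model, key), model, key)
assert w.device != torch.device("meta")
assert torch.equal(w, real_weight)
```

The reason the helper exists at all: when accelerate offloads a model, it swaps each parameter for a meta-device placeholder and copies the real tensor back in via a pre-forward hook. Merging weights happens outside any forward pass, so it has to reach into the hook's weights_map directly, which is why the comment in the diff warns that this depends on accelerate internals and might break.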