Add ControlNet support.

This commit is contained in:
comfyanonymous
2023-02-16 10:38:08 -05:00
parent bc69fb5245
commit 4efa67fa12
9 changed files with 580 additions and 63 deletions

View File

@@ -48,7 +48,7 @@ print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM"][vram_s
current_loaded_model = None
current_gpu_controlnets = []
model_accelerated = False
@@ -56,6 +56,7 @@ model_accelerated = False
def unload_model():
global current_loaded_model
global model_accelerated
global current_gpu_controlnets
if current_loaded_model is not None:
if model_accelerated:
accelerate.hooks.remove_hook_from_submodules(current_loaded_model.model)
@@ -64,6 +65,10 @@ def unload_model():
current_loaded_model.model.cpu()
current_loaded_model.unpatch_model()
current_loaded_model = None
if len(current_gpu_controlnets) > 0:
for n in current_gpu_controlnets:
n.cpu()
current_gpu_controlnets = []
def load_model_gpu(model):
@@ -95,6 +100,16 @@ def load_model_gpu(model):
model_accelerated = True
return current_loaded_model
def load_controlnet_gpu(models):
global current_gpu_controlnets
for m in current_gpu_controlnets:
if m not in models:
m.cpu()
current_gpu_controlnets = []
for m in models:
current_gpu_controlnets.append(m.cuda())
def get_free_memory():
dev = torch.cuda.current_device()