seperates out arg parser and imports args

This commit is contained in:
EllangoK
2023-04-05 23:41:23 -04:00
parent dd29966f8a
commit e5e587b1c0
4 changed files with 88 additions and 84 deletions

View File

@@ -1,36 +1,35 @@
import psutil
from enum import Enum
from cli_args import args
CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
MPS = 5
class VRAMState(Enum):
CPU = 0
NO_VRAM = 1
LOW_VRAM = 2
NORMAL_VRAM = 3
HIGH_VRAM = 4
MPS = 5
accelerate_enabled = False
vram_state = NORMAL_VRAM
# Determine VRAM State
vram_state = VRAMState.NORMAL_VRAM
set_vram_to = VRAMState.NORMAL_VRAM
total_vram = 0
total_vram_available_mb = -1
import sys
import psutil
forced_cpu = "--cpu" in sys.argv
set_vram_to = NORMAL_VRAM
accelerate_enabled = False
try:
import torch
total_vram = torch.cuda.mem_get_info(torch.cuda.current_device())[1] / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
forced_normal_vram = "--normalvram" in sys.argv
if not forced_normal_vram and not forced_cpu:
if not args.normalvram and not args.cpu:
if total_vram <= 4096:
print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
set_vram_to = LOW_VRAM
set_vram_to = VRAMState.LOW_VRAM
elif total_vram > total_ram * 1.1 and total_vram > 14336:
print("Enabling highvram mode because your GPU has more vram than your computer has ram. If you don't want this use: --normalvram")
vram_state = HIGH_VRAM
vram_state = VRAMState.HIGH_VRAM
except:
pass
@@ -39,34 +38,32 @@ try:
except:
OOM_EXCEPTION = Exception
if "--disable-xformers" in sys.argv:
XFORMERS_IS_AVAILBLE = False
if args.disable_xformers:
XFORMERS_IS_AVAILABLE = False
else:
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILBLE = False
XFORMERS_IS_AVAILABLE = False
ENABLE_PYTORCH_ATTENTION = False
if "--use-pytorch-cross-attention" in sys.argv:
ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
ENABLE_PYTORCH_ATTENTION = True
XFORMERS_IS_AVAILBLE = False
XFORMERS_IS_AVAILABLE = False
if args.lowvram:
set_vram_to = VRAMState.LOW_VRAM
elif args.novram:
set_vram_to = VRAMState.NO_VRAM
elif args.highvram:
vram_state = VRAMState.HIGH_VRAM
if "--lowvram" in sys.argv:
set_vram_to = LOW_VRAM
if "--novram" in sys.argv:
set_vram_to = NO_VRAM
if "--highvram" in sys.argv:
vram_state = HIGH_VRAM
if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
try:
import accelerate
accelerate_enabled = True
@@ -81,14 +78,14 @@ if set_vram_to == LOW_VRAM or set_vram_to == NO_VRAM:
try:
if torch.backends.mps.is_available():
vram_state = MPS
vram_state = VRAMState.MPS
except:
pass
if forced_cpu:
vram_state = CPU
if args.cpu:
vram_state = VRAMState.CPU
print("Set vram state to:", ["CPU", "NO VRAM", "LOW VRAM", "NORMAL VRAM", "HIGH VRAM", "MPS"][vram_state])
print(f"Set vram state to: {vram_state.name}")
current_loaded_model = None
@@ -109,12 +106,12 @@ def unload_model():
model_accelerated = False
#never unload models from GPU on high vram
if vram_state != HIGH_VRAM:
if vram_state != VRAMState.HIGH_VRAM:
current_loaded_model.model.cpu()
current_loaded_model.unpatch_model()
current_loaded_model = None
if vram_state != HIGH_VRAM:
if vram_state != VRAMState.HIGH_VRAM:
if len(current_gpu_controlnets) > 0:
for n in current_gpu_controlnets:
n.cpu()
@@ -135,19 +132,19 @@ def load_model_gpu(model):
model.unpatch_model()
raise e
current_loaded_model = model
if vram_state == CPU:
if vram_state == VRAMState.CPU:
pass
elif vram_state == MPS:
elif vram_state == VRAMState.MPS:
mps_device = torch.device("mps")
real_model.to(mps_device)
pass
elif vram_state == NORMAL_VRAM or vram_state == HIGH_VRAM:
elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
model_accelerated = False
real_model.cuda()
else:
if vram_state == NO_VRAM:
if vram_state == VRAMState.NO_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
elif vram_state == LOW_VRAM:
elif vram_state == VRAMState.LOW_VRAM:
device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
accelerate.dispatch_model(real_model, device_map=device_map, main_device="cuda")
@@ -157,10 +154,10 @@ def load_model_gpu(model):
def load_controlnet_gpu(models):
global current_gpu_controlnets
global vram_state
if vram_state == CPU:
if vram_state == VRAMState.CPU:
return
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return
@@ -176,20 +173,20 @@ def load_controlnet_gpu(models):
def load_if_low_vram(model):
global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cuda()
return model
def unload_if_low_vram(model):
global vram_state
if vram_state == LOW_VRAM or vram_state == NO_VRAM:
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
return model.cpu()
return model
def get_torch_device():
if vram_state == MPS:
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == CPU:
if vram_state == VRAMState.CPU:
return torch.device("cpu")
else:
return torch.cuda.current_device()
@@ -201,9 +198,9 @@ def get_autocast_device(dev):
def xformers_enabled():
if vram_state == CPU:
if vram_state == VRAMState.CPU:
return False
return XFORMERS_IS_AVAILBLE
return XFORMERS_IS_AVAILABLE
def xformers_enabled_vae():
@@ -243,7 +240,7 @@ def get_free_memory(dev=None, torch_free_too=False):
def maximum_batch_area():
global vram_state
if vram_state == NO_VRAM:
if vram_state == VRAMState.NO_VRAM:
return 0
memory_free = get_free_memory() / (1024 * 1024)
@@ -252,11 +249,11 @@ def maximum_batch_area():
def cpu_mode():
global vram_state
return vram_state == CPU
return vram_state == VRAMState.CPU
def mps_mode():
global vram_state
return vram_state == MPS
return vram_state == VRAMState.MPS
def should_use_fp16():
if cpu_mode() or mps_mode():