Replace prints with logging and add --verbose argument.

comfyanonymous
2024-03-10 11:37:08 -04:00
parent 4656273e72
commit 65397ce601
12 changed files with 90 additions and 65 deletions
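The diff excerpt below covers only one of the 12 changed files, so the --verbose wiring itself is not visible here. A minimal sketch of how such a flag is typically connected to the standard logging module (the flag name, help text, and basicConfig format are illustrative assumptions, not copied from this commit):

import argparse
import logging

parser = argparse.ArgumentParser()
# Hypothetical flag; the real CLI may define or store it differently.
parser.add_argument("--verbose", action="store_true", help="Enable debug-level log output.")
args = parser.parse_args()

# Configure the root logger once at startup: warnings and above are always
# shown, debug/info messages only when --verbose is passed.
logging.basicConfig(
    format="%(message)s",
    level=logging.DEBUG if args.verbose else logging.WARNING,
)

logging.warning("always shown")
logging.debug("only shown with --verbose")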


@@ -1,4 +1,5 @@
 import psutil
+import logging
 from enum import Enum
 from comfy.cli_args import args
 import comfy.utils
@@ -29,7 +30,7 @@ lowvram_available = True
 xpu_available = False
 if args.deterministic:
-    print("Using deterministic algorithms for pytorch")
+    logging.warning("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
 directml_enabled = False
@@ -41,7 +42,7 @@ if args.directml is not None:
         directml_device = torch_directml.device()
     else:
         directml_device = torch_directml.device(device_index)
-    print("Using directml with device:", torch_directml.device_name(device_index))
+    logging.warning("Using directml with device: {}".format(torch_directml.device_name(device_index)))
     # torch_directml.disable_tiled_resources(True)
     lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
@@ -117,10 +118,10 @@ def get_total_memory(dev=None, torch_total_too=False):
 total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
 total_ram = psutil.virtual_memory().total / (1024 * 1024)
-print("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
+logging.warning("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
 if not args.normalvram and not args.cpu:
     if lowvram_available and total_vram <= 4096:
-        print("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
+        logging.warning("Trying to enable lowvram mode because your GPU seems to have 4GB or less. If you don't want this use: --normalvram")
         set_vram_to = VRAMState.LOW_VRAM
 try:
@@ -143,12 +144,10 @@ else:
             pass
         try:
             XFORMERS_VERSION = xformers.version.__version__
-            print("xformers version:", XFORMERS_VERSION)
+            logging.warning("xformers version: {}".format(XFORMERS_VERSION))
             if XFORMERS_VERSION.startswith("0.0.18"):
-                print()
-                print("WARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
-                print("Please downgrade or upgrade xformers to a different version.")
-                print()
+                logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
+                logging.warning("Please downgrade or upgrade xformers to a different version.\n")
                 XFORMERS_ENABLED_VAE = False
         except:
             pass
@@ -213,11 +212,11 @@ elif args.highvram or args.gpu_only:
 FORCE_FP32 = False
 FORCE_FP16 = False
 if args.force_fp32:
-    print("Forcing FP32, if this improves things please report it.")
+    logging.warning("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
 if args.force_fp16:
-    print("Forcing FP16.")
+    logging.warning("Forcing FP16.")
     FORCE_FP16 = True
 if lowvram_available:
@@ -231,12 +230,12 @@ if cpu_state != CPUState.GPU:
 if cpu_state == CPUState.MPS:
     vram_state = VRAMState.SHARED
-print(f"Set vram state to: {vram_state.name}")
+logging.warning(f"Set vram state to: {vram_state.name}")
 DISABLE_SMART_MEMORY = args.disable_smart_memory
 if DISABLE_SMART_MEMORY:
-    print("Disabling smart memory management")
+    logging.warning("Disabling smart memory management")
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
@@ -254,11 +253,11 @@ def get_torch_device_name(device):
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
 try:
-    print("Device:", get_torch_device_name(get_torch_device()))
+    logging.warning("Device: {}".format(get_torch_device_name(get_torch_device())))
 except:
-    print("Could not pick default device.")
+    logging.warning("Could not pick default device.")
-print("VAE dtype:", VAE_DTYPE)
+logging.warning("VAE dtype: {}".format(VAE_DTYPE))
 current_loaded_models = []
@@ -301,7 +300,7 @@ class LoadedModel:
             raise e
         if lowvram_model_memory > 0:
-            print("loading in lowvram mode", lowvram_model_memory/(1024 * 1024))
+            logging.warning("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
             mem_counter = 0
             for m in self.real_model.modules():
                 if hasattr(m, "comfy_cast_weights"):
@@ -314,7 +313,7 @@ class LoadedModel:
                 elif hasattr(m, "weight"): #only modules with comfy_cast_weights can be set to lowvram mode
                     m.to(self.device)
                     mem_counter += module_size(m)
-                    print("lowvram: loaded module regularly", m)
+                    logging.warning("lowvram: loaded module regularly {}".format(m))
             self.model_accelerated = True
@@ -348,7 +347,7 @@ def unload_model_clones(model):
             to_unload = [i] + to_unload
     for i in to_unload:
-        print("unload clone", i)
+        logging.warning("unload clone {}".format(i))
         current_loaded_models.pop(i).model_unload()
 def free_memory(memory_required, device, keep_loaded=[]):
@@ -390,7 +389,7 @@ def load_models_gpu(models, memory_required=0):
             models_already_loaded.append(loaded_model)
         else:
             if hasattr(x, "model"):
-                print(f"Requested to load {x.model.__class__.__name__}")
+                logging.warning(f"Requested to load {x.model.__class__.__name__}")
             models_to_load.append(loaded_model)
     if len(models_to_load) == 0:
@@ -400,7 +399,7 @@ def load_models_gpu(models, memory_required=0):
                 free_memory(extra_mem, d, models_already_loaded)
         return
-    print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
+    logging.warning(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
     total_memory_required = {}
     for loaded_model in models_to_load:
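
A recurring detail in the hunks above: print accepts several positional arguments and joins them with spaces, while a logging call takes a single message string, so each converted call folds its arguments into the message with str.format (or keeps an existing f-string). A standalone sketch of that pattern, with a placeholder value rather than anything taken from the commit:

import logging

logging.basicConfig(format="%(message)s", level=logging.WARNING)

device_name = "example-device"  # placeholder value, not from the commit

# Before: print("Device:", device_name)
# After:  fold the arguments into one message string.
logging.warning("Device: {}".format(device_name))

# logging can also defer formatting until the record is actually emitted:
logging.warning("Device: %s", device_name)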