diff --git a/comfy/model_management.py b/comfy/model_management.py
index 9c3147d76..a492ca6b9 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -4,16 +4,22 @@ from comfy.cli_args import args
 import torch
 
 class VRAMState(Enum):
-    CPU = 0
+    DISABLED = 0
     NO_VRAM = 1
     LOW_VRAM = 2
     NORMAL_VRAM = 3
     HIGH_VRAM = 4
-    MPS = 5
+    SHARED = 5
+
+class CPUState(Enum):
+    GPU = 0
+    CPU = 1
+    MPS = 2
 
 # Determine VRAM State
 vram_state = VRAMState.NORMAL_VRAM
 set_vram_to = VRAMState.NORMAL_VRAM
+cpu_state = CPUState.GPU
 
 total_vram = 0
 
@@ -40,15 +46,25 @@ try:
 except:
     pass
 
+try:
+    if torch.backends.mps.is_available():
+        cpu_state = CPUState.MPS
+except:
+    pass
+
+if args.cpu:
+    cpu_state = CPUState.CPU
+
 def get_torch_device():
     global xpu_available
     global directml_enabled
+    global cpu_state
     if directml_enabled:
         global directml_device
         return directml_device
-    if vram_state == VRAMState.MPS:
+    if cpu_state == CPUState.MPS:
         return torch.device("mps")
-    if vram_state == VRAMState.CPU:
+    if cpu_state == CPUState.CPU:
         return torch.device("cpu")
     else:
         if xpu_available:
@@ -143,8 +159,6 @@ if args.force_fp32:
     print("Forcing FP32, if this improves things please report it.")
     FORCE_FP32 = True
 
-
-
 if lowvram_available:
     try:
         import accelerate
@@ -157,17 +171,15 @@ if lowvram_available:
         lowvram_available = False
 
 
-try:
-    if torch.backends.mps.is_available():
-        vram_state = VRAMState.MPS
-except:
-    pass
+if cpu_state != CPUState.GPU:
+    vram_state = VRAMState.DISABLED
 
-if args.cpu:
-    vram_state = VRAMState.CPU
+if cpu_state == CPUState.MPS:
+    vram_state = VRAMState.SHARED
 
 print(f"Set vram state to: {vram_state.name}")
 
+
 def get_torch_device_name(device):
     if hasattr(device, 'type'):
         if device.type == "cuda":
@@ -241,13 +253,9 @@ def load_model_gpu(model):
 
     current_loaded_model = model
 
-    if vram_set_state == VRAMState.CPU:
+    if vram_set_state == VRAMState.DISABLED:
         pass
-    elif vram_set_state == VRAMState.MPS:
-        mps_device = torch.device("mps")
-        real_model.to(mps_device)
-        pass
-    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM:
+    elif vram_set_state == VRAMState.NORMAL_VRAM or vram_set_state == VRAMState.HIGH_VRAM or vram_set_state == VRAMState.SHARED:
         model_accelerated = False
         real_model.to(get_torch_device())
     else:
@@ -263,7 +271,7 @@ def load_controlnet_gpu(control_models):
     global current_gpu_controlnets
     global vram_state
 
-    if vram_state == VRAMState.CPU:
+    if vram_state == VRAMState.DISABLED:
         return
 
     if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
@@ -308,7 +316,8 @@ def get_autocast_device(dev):
 def xformers_enabled():
     global xpu_available
     global directml_enabled
-    if vram_state == VRAMState.CPU:
+    global cpu_state
+    if cpu_state != CPUState.GPU:
         return False
     if xpu_available:
         return False
@@ -380,12 +389,12 @@ def maximum_batch_area():
     return int(max(area, 0))
 
 def cpu_mode():
-    global vram_state
-    return vram_state == VRAMState.CPU
+    global cpu_state
+    return cpu_state == CPUState.CPU
 
 def mps_mode():
-    global vram_state
-    return vram_state == VRAMState.MPS
+    global cpu_state
+    return cpu_state == CPUState.MPS
 
 def should_use_fp16():
     global xpu_available
@@ -417,8 +426,8 @@ def should_use_fp16():
 
 def soft_empty_cache():
     global xpu_available
-    global vram_state
-    if vram_state == VRAMState.MPS:
+    global cpu_state
+    if cpu_state == CPUState.MPS:
         torch.mps.empty_cache()
     elif xpu_available:
         torch.xpu.empty_cache()