From bceccca0e59862c3410b5d99b47fe1e01ba914af Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Thu, 6 Apr 2023 23:52:34 -0400
Subject: [PATCH] Small refactor.

---
 comfy/model_management.py | 14 +++-----------
 1 file changed, 3 insertions(+), 11 deletions(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 92c59efe7..504da2190 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -129,7 +129,6 @@ def load_model_gpu(model):
     global current_loaded_model
     global vram_state
     global model_accelerated
-    global xpu_available
 
     if model is current_loaded_model:
         return
@@ -148,17 +147,14 @@ def load_model_gpu(model):
         pass
     elif vram_state == VRAMState.NORMAL_VRAM or vram_state == VRAMState.HIGH_VRAM:
         model_accelerated = False
-        if xpu_available:
-            real_model.to("xpu")
-        else:
-            real_model.cuda()
+        real_model.to(get_torch_device())
     else:
         if vram_state == VRAMState.NO_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "256MiB", "cpu": "16GiB"})
         elif vram_state == VRAMState.LOW_VRAM:
             device_map = accelerate.infer_auto_device_map(real_model, max_memory={0: "{}MiB".format(total_vram_available_mb), "cpu": "16GiB"})
 
-        accelerate.dispatch_model(real_model, device_map=device_map, main_device="xpu" if xpu_available else "cuda")
+        accelerate.dispatch_model(real_model, device_map=device_map, main_device=get_torch_device())
         model_accelerated = True
     return current_loaded_model
 
@@ -184,12 +180,8 @@ def load_controlnet_gpu(models):
 
 def load_if_low_vram(model):
     global vram_state
-    global xpu_available
     if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
-        if xpu_available:
-            return model.to("xpu")
-        else:
-            return model.cuda()
+        return model.to(get_torch_device())
     return model
 
 def unload_if_low_vram(model):
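
The refactor folds the repeated xpu/cuda branching into a single get_torch_device()
call, which the patch assumes is already defined elsewhere in
comfy/model_management.py. Its body is outside this diff; below is a minimal
sketch of what such a helper could look like, assuming it prefers an Intel XPU
when one is available, then CUDA, then falls back to CPU:

    import torch

    def get_torch_device():
        # Hypothetical sketch; the real helper lives outside this diff.
        # torch.xpu only exists when intel-extension-for-pytorch is installed,
        # so guard with hasattr before probing for an XPU.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device("xpu")
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

Centralizing the choice this way keeps call sites like
real_model.to(get_torch_device()) and accelerate.dispatch_model(...,
main_device=get_torch_device()) backend-agnostic, so supporting another
backend later only touches the helper, not every load path.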