Basic torch_directml support. Use --directml to use it.

This commit is contained in:
comfyanonymous
2023-04-28 14:28:57 -04:00
parent ab9a9deff4
commit 3baded9892
2 changed files with 27 additions and 1 deletions

View File

@@ -20,6 +20,13 @@ total_vram_available_mb = -1
accelerate_enabled = False
xpu_available = False
directml_enabled = False
if args.directml:
import torch_directml
print("Using directml")
directml_enabled = True
# torch_directml.disable_tiled_resources(True)
try:
import torch
try:
@@ -217,6 +224,9 @@ def unload_if_low_vram(model):
def get_torch_device():
global xpu_available
global directml_enabled
if directml_enabled:
return torch_directml.device()
if vram_state == VRAMState.MPS:
return torch.device("mps")
if vram_state == VRAMState.CPU:
@@ -234,8 +244,14 @@ def get_autocast_device(dev):
def xformers_enabled():
global xpu_available
global directml_enabled
if vram_state == VRAMState.CPU:
return False
if xpu_available:
return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@@ -251,6 +267,7 @@ def pytorch_attention_enabled():
def get_free_memory(dev=None, torch_free_too=False):
global xpu_available
global directml_enabled
if dev is None:
dev = get_torch_device()
@@ -258,7 +275,10 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_total = psutil.virtual_memory().available
mem_free_torch = mem_free_total
else:
if xpu_available:
if directml_enabled:
mem_free_total = 1024 * 1024 * 1024 #TODO
mem_free_torch = mem_free_total
elif xpu_available:
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
mem_free_torch = mem_free_total
else:
@@ -293,9 +313,14 @@ def mps_mode():
def should_use_fp16():
global xpu_available
global directml_enabled
if FORCE_FP32:
return False
if directml_enabled:
return False
if cpu_mode() or mps_mode() or xpu_available:
return False #TODO ?