System stats endpoint

This commit is contained in:
space-nuko
2023-06-01 23:26:23 -05:00
parent 1bbd3f7fe1
commit b5dd15c67a
2 changed files with 51 additions and 0 deletions

View File

@@ -308,6 +308,33 @@ def pytorch_attention_flash_attention():
return True
return False
def get_total_memory(dev=None, torch_total_too=False):
global xpu_available
global directml_enabled
if dev is None:
dev = get_torch_device()
if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = psutil.virtual_memory().total
else:
if directml_enabled:
mem_total = 1024 * 1024 * 1024 #TODO
mem_total_torch = mem_total
elif xpu_available:
mem_total = torch.xpu.get_device_properties(dev).total_memory
mem_total_torch = mem_total
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
_, mem_total_cuda = torch.cuda.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_cuda + mem_total_torch
if torch_total_too:
return (mem_total, mem_total_torch)
else:
return mem_total
def get_free_memory(dev=None, torch_free_too=False):
global xpu_available
global directml_enabled