Support ascend npu (#5436)

* support ascend npu

Co-authored-by: YukMingLaw <lymmm2@163.com>
Co-authored-by: starmountain1997 <guozr1997@hotmail.com>
Co-authored-by: Ginray <ginray0215@gmail.com>
Author: Huazhong Ji
Date: 2024-12-27 08:36:50 +08:00
Committed by: GitHub
Parent: ee9547ba31
Commit: c4bfdba330
2 changed files with 50 additions and 1 deletion

@@ -86,6 +86,13 @@ try:
 except:
     pass

+try:
+    import torch_npu
+    _ = torch.npu.device_count()
+    npu_available = torch.npu.is_available()
+except:
+    npu_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
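
Side note on the detection pattern above: importing torch_npu attaches the torch.npu namespace to torch as a side effect, and the throwaway torch.npu.device_count() call forces driver initialization so a broken Ascend install fails inside the try block rather than later at tensor-allocation time. A minimal standalone sketch of the same probe (assumes nothing beyond torch being installed; torch_npu may be absent):

import torch

try:
    import torch_npu  # side effect: attaches the torch.npu namespace
    _ = torch.npu.device_count()  # touch the driver so broken installs fail here
    npu_available = torch.npu.is_available()
except Exception:
    npu_available = False

print("Ascend NPU available:", npu_available)
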
@@ -97,6 +104,12 @@ def is_intel_xpu():
             return True
     return False

+def is_ascend_npu():
+    global npu_available
+    if npu_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -110,6 +123,8 @@ def get_torch_device():
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
+        elif is_ascend_npu():
+            return torch.device("npu", torch.npu.current_device())
         else:
             return torch.device(torch.cuda.current_device())
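
With is_ascend_npu() slotted into the dispatch chain, callers of get_torch_device() stay device-agnostic. A hedged usage sketch, assuming this module is importable as comfy.model_management (its path in ComfyUI):

import torch
import comfy.model_management as mm

device = mm.get_torch_device()        # torch.device("npu", 0) on an Ascend box
x = torch.randn(4, 4, device=device)  # allocation lands on the NPU transparently
print(device, x.device)
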
@@ -130,6 +145,12 @@ def get_total_memory(dev=None, torch_total_too=False):
             mem_reserved = stats['reserved_bytes.all.current']
             mem_total_torch = mem_reserved
             mem_total = torch.xpu.get_device_properties(dev).total_memory
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total_npu = torch.npu.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+            mem_total = mem_total_npu
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
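
Note the tuple order: torch.npu.mem_get_info mirrors torch.cuda.mem_get_info and returns (free_bytes, total_bytes), which is why this hunk unpacks the second element for the total while get_free_memory further down takes the first. A small guarded sketch that only queries when an NPU is actually present:

import torch

try:
    import torch_npu  # optional; attaches torch.npu
except ImportError:
    torch_npu = None

if torch_npu is not None and torch.npu.is_available():
    free_b, total_b = torch.npu.mem_get_info(torch.npu.current_device())
    print(f"free {free_b / 1024**3:.1f} GiB of {total_b / 1024**3:.1f} GiB")
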
@@ -209,7 +230,7 @@ try:
     if int(torch_version[0]) >= 2:
         if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
-    if is_intel_xpu():
+    if is_intel_xpu() or is_ascend_npu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
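
ENABLE_PYTORCH_ATTENTION is the switch that, elsewhere in the codebase, routes cross-attention through torch's built-in torch.nn.functional.scaled_dot_product_attention, so enabling it by default for Ascend (as already done for Intel XPU) gives NPUs the fused kernel path on torch >= 2 unless the user forces split or quad attention. A minimal sketch of the primitive being enabled:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, 8, 128, 64)
v = torch.randn(1, 8, 128, 64)
out = F.scaled_dot_product_attention(q, k, v)  # fused kernel where the backend has one
print(out.shape)  # torch.Size([1, 8, 128, 64])
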
@@ -274,6 +295,8 @@ def get_torch_device_name(device):
             return "{}".format(device.type)
     elif is_intel_xpu():
         return "{} {}".format(device, torch.xpu.get_device_name(device))
+    elif is_ascend_npu():
+        return "{} {}".format(device, torch.npu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -860,6 +883,8 @@ def xformers_enabled():
         return False
     if is_intel_xpu():
         return False
+    if is_ascend_npu():
+        return False
     if directml_enabled:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -884,6 +909,8 @@ def pytorch_attention_flash_attention():
             return True
         if is_intel_xpu():
             return True
+        if is_ascend_npu():
+            return True
     return False

 def mac_version():
@@ -923,6 +950,13 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_torch = mem_reserved - mem_active
             mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
             mem_free_total = mem_free_xpu + mem_free_torch
+        elif is_ascend_npu():
+            stats = torch.npu.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_npu, _ = torch.npu.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_npu + mem_free_torch
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
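
The bookkeeping here sums two pools: memory still free at the driver level (mem_get_info) plus memory the torch caching allocator has reserved but is not actively using, since the latter can be reused without a fresh driver allocation. Worked through with illustrative numbers:

# Illustrative numbers only -- the shape of the computation above
mem_free_npu = 20 * 1024**3                  # free per the driver (mem_get_info)
mem_active = 4 * 1024**3                     # bytes backing live tensors
mem_reserved = 6 * 1024**3                   # bytes held by the caching allocator
mem_free_torch = mem_reserved - mem_active   # reusable without a driver call: 2 GiB
mem_free_total = mem_free_npu + mem_free_torch
print(mem_free_total / 1024**3)              # 22.0 GiB effectively available
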
@@ -984,6 +1018,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True

+    if is_ascend_npu():
+        return True
+
     if torch.version.hip:
         return True
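
Returning True here makes fp16 the default weight dtype on Ascend, short-circuiting the model-size and performance heuristics applied further down for other backends. A hedged caller-side sketch, again assuming the comfy.model_management import path:

import torch
import comfy.model_management as mm  # assumed import path

dtype = torch.float16 if mm.should_use_fp16() else torch.float32
weights = torch.zeros(1024, 1024, dtype=dtype, device=mm.get_torch_device())
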
@@ -1081,6 +1118,8 @@ def soft_empty_cache(force=False):
         torch.mps.empty_cache()
     elif is_intel_xpu():
         torch.xpu.empty_cache()
+    elif is_ascend_npu():
+        torch.npu.empty_cache()
     elif torch.cuda.is_available():
         if force or is_nvidia(): #This seems to make things worse on ROCm so I only do it for cuda
             torch.cuda.empty_cache()
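
Finally, soft_empty_cache lets the rest of the codebase release cached device memory without knowing which backend it is on. A usage sketch under the same assumed import path:

import torch
import comfy.model_management as mm  # assumed import path

x = torch.randn(8192, 8192, device=mm.get_torch_device())
del x                  # returns the block to the caching allocator, not the driver
mm.soft_empty_cache()  # on Ascend this now reaches torch.npu.empty_cache()
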