From a02e6d5acb129072438b620a47080e0998215317 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=B3n=C3=A1n=20Carrigan?= Date: Sun, 6 Feb 2022 21:18:33 +0000 Subject: [PATCH] feat: async build_spec --- lua/neotest-python/base.lua | 61 ++++++++++++++++++++++--------------- lua/neotest-python/init.lua | 40 ++++++++++++++++-------- neotest_python/pytest.py | 11 ------- neotest_python/unittest.py | 8 ----- 4 files changed, 64 insertions(+), 56 deletions(-) diff --git a/lua/neotest-python/base.lua b/lua/neotest-python/base.lua index 8af9c53..c2e11cd 100644 --- a/lua/neotest-python/base.lua +++ b/lua/neotest-python/base.lua @@ -1,40 +1,49 @@ +local async = require("plenary.async") local lib = require("neotest.lib") local Path = require("plenary.path") local M = {} function M.is_test_file(file_path) - if not vim.endswith(file_path, ".py") then - return false - end - local elems = vim.split(file_path, Path.path.sep) - local file_name = elems[#elems] - return vim.startswith(file_name, "test_") + if not vim.endswith(file_path, ".py") then + return false + end + local elems = vim.split(file_path, Path.path.sep) + local file_name = elems[#elems] + return vim.startswith(file_name, "test_") +end + +M.module_exists = function(module, python_command) + return lib.process.run(vim.tbl_flatten({ + python_command, + "-c", + "import imp; imp.find_module('"..module.."')", + })) == 0 end ---@return string[] function M.get_python_command(root) - -- Use activated virtualenv. - if vim.env.VIRTUAL_ENV then - return { Path:new(vim.env.VIRTUAL_ENV, "bin", "python").filename } - end + -- Use activated virtualenv. + if vim.env.VIRTUAL_ENV then + return { Path:new(vim.env.VIRTUAL_ENV, "bin", "python").filename } + end - for _, pattern in ipairs({ "*", ".*" }) do - local match = vim.fn.glob(Path:new(root or vim.fn.getcwd(), pattern, "pyvenv.cfg").filename) - if match ~= "" then - return { (Path:new(match):parent() / "bin" / "python").filename } - end - end + for _, pattern in ipairs({ "*", ".*" }) do + local match = async.fn.glob(Path:new(root or async.fn.getcwd(), pattern, "pyvenv.cfg").filename) + if match ~= "" then + return { (Path:new(match):parent() / "bin" / "python").filename } + end + end - if lib.files.exists("Pipfile") then - return { "pipenv", "run", "python" } - end - -- Fallback to system Python. - return { vim.fn.exepath("python3") or vim.fn.exepath("python") or "python" } + if lib.files.exists("Pipfile") then + return { "pipenv", "run", "python" } + end + -- Fallback to system Python. + return { async.fn.exepath("python3") or async.fn.exepath("python") or "python" } end function M.parse_positions(file_path) - local query = [[ + local query = [[ ((function_definition name: (identifier) @test.name) (#match? @test.name "^test_")) @@ -44,18 +53,19 @@ function M.parse_positions(file_path) name: (identifier) @namespace.name) @namespace.definition ]] - return lib.treesitter.parse_positions(file_path, query) + return lib.treesitter.parse_positions(file_path, query) end -function M.get_strategy_config(strategy, python_script, args) +function M.get_strategy_config(strategy, python, python_script, args) local config = { dap = function() return { type = "python", name = "Neotest Debugger", request = "launch", + python = python, program = python_script, - cwd = vim.fn.getcwd(), + cwd = async.fn.getcwd(), args = args, justMyCode= false, } @@ -67,3 +77,4 @@ function M.get_strategy_config(strategy, python_script, args) end return M + diff --git a/lua/neotest-python/init.lua b/lua/neotest-python/init.lua index fd729ca..1c79eea 100644 --- a/lua/neotest-python/init.lua +++ b/lua/neotest-python/init.lua @@ -1,4 +1,4 @@ -local logger = require("neotest.logging") +local async = require("plenary.async") local Path = require("plenary.path") local lib = require("neotest.lib") local base = require("neotest-python.base") @@ -17,7 +17,13 @@ local get_args = function(runner, position) return lib.vim_test.collect_args("python", runner, position) end -local get_runner = function() +local stored_runners = {} + +local get_runner = function(python_command) + local command_str = table.concat(python_command, " ") + if stored_runners[command_str] then + return stored_runners[command_str] + end local vim_test_runner = vim.g["test#python#runner"] if vim_test_runner == "pyunit" then return "unittest" @@ -25,16 +31,21 @@ local get_runner = function() if vim_test_runner and lib.func_util.index({ "unittest", "pytest" }, vim_test_runner) then return vim_test_runner end - if vim.fn.executable("pytest") == 1 then - return "pytest" - end - return "unittest" + local runner = base.module_exists("pytest", python_command) and "pytest" or "unittest" + stored_runners[command_str] = runner + return runner end ---@type NeotestAdapter local PythonNeotestAdapter = { name = "neotest-python" } -PythonNeotestAdapter.root = lib.files.match_root_pattern("pyproject.toml", "setup.cfg", "mypy.ini", "pytest.ini", "setup.py") +PythonNeotestAdapter.root = lib.files.match_root_pattern( + "pyproject.toml", + "setup.cfg", + "mypy.ini", + "pytest.ini", + "setup.py" +) function PythonNeotestAdapter.is_test_file(file_path) return base.is_test_file(file_path) @@ -53,18 +64,23 @@ function PythonNeotestAdapter.discover_positions(path) name: (identifier) @namespace.name) @namespace.definition ]] + local root = PythonNeotestAdapter.root(path) + local python = base.get_python_command(root) + local runner = get_runner(python) return lib.treesitter.parse_positions(path, query, { - require_namespaces = get_runner() == "unittest", + require_namespaces = runner == "unittest", }) end +---@async ---@param args NeotestRunArgs ---@return NeotestRunSpec function PythonNeotestAdapter.build_spec(args) local position = args.tree:data() - local results_path = vim.fn.tempname() - local runner = get_runner() - local python = base.get_python_command(vim.fn.getcwd()) + local results_path = async.fn.tempname() + local root = PythonNeotestAdapter.root(position.path) + local python = base.get_python_command(root) + local runner = get_runner(python) local script_args = vim.tbl_flatten({ "--results-file", results_path, @@ -86,7 +102,7 @@ function PythonNeotestAdapter.build_spec(args) context = { results_path = results_path, }, - strategy = base.get_strategy_config(args.strategy, python_script, script_args), + strategy = base.get_strategy_config(args.strategy, python, python_script, script_args), } end diff --git a/neotest_python/pytest.py b/neotest_python/pytest.py index 174498f..67aa64b 100644 --- a/neotest_python/pytest.py +++ b/neotest_python/pytest.py @@ -76,19 +76,8 @@ class PytestNeotestAdapter(NeotestAdapter): "errors": errors, }, ) - results[abs_path] = self.update_result( - results.get(abs_path), - { - "short": None, - "status": NeotestResultStatus(report.outcome), - "errors": errors, - }, - ) import pytest pytest.main(args=args, plugins=[NeotestResultCollector]) return results - - def update_report(self, report: Optional[Dict], update: Dict): - ... diff --git a/neotest_python/unittest.py b/neotest_python/unittest.py index ad8021e..b10082d 100644 --- a/neotest_python/unittest.py +++ b/neotest_python/unittest.py @@ -76,14 +76,6 @@ class UnittestNeotestAdapter(NeotestAdapter): "short": None, }, ) - results[case_file] = self.update_result( - results.get(case_file), - { - "status": NeotestResultStatus.FAILED, - "errors": [{"message": message, "line": error_line}], - "short": None, - }, - ) for case, message in result.skipped: results[self.case_id(case)] = self.update_result( results[self.case_id(case)],