feat(pytest): populate parameterized test instances (#36)
Co-authored-by: Rónán Carrigan <rcarriga@tcd.ie>
This commit is contained in:
@@ -38,7 +38,9 @@ require("neotest").setup({
|
|||||||
is_test_file = function(file_path)
|
is_test_file = function(file_path)
|
||||||
...
|
...
|
||||||
end,
|
end,
|
||||||
|
-- !!EXPERIMENTAL!! Enable shelling out to `pytest` to discover test
|
||||||
|
-- instances for files containing a parametrize mark (default: false)
|
||||||
|
pytest_discover_instances = true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
local async = require("neotest.async")
|
local async = require("neotest.async")
|
||||||
local lib = require("neotest.lib")
|
local lib = require("neotest.lib")
|
||||||
local base = require("neotest-python.base")
|
local base = require("neotest-python.base")
|
||||||
|
local pytest = require("neotest-python.pytest")
|
||||||
|
|
||||||
local function get_script()
|
local function get_script()
|
||||||
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
|
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
|
||||||
@@ -15,6 +16,7 @@ end
|
|||||||
|
|
||||||
local dap_args
|
local dap_args
|
||||||
local is_test_file = base.is_test_file
|
local is_test_file = base.is_test_file
|
||||||
|
local pytest_discover_instances = false
|
||||||
|
|
||||||
local function get_strategy_config(strategy, python, program, args)
|
local function get_strategy_config(strategy, python, program, args)
|
||||||
local config = {
|
local config = {
|
||||||
@@ -80,8 +82,13 @@ function PythonNeotestAdapter.filter_dir(name)
|
|||||||
end
|
end
|
||||||
|
|
||||||
---@async
|
---@async
|
||||||
---@return Tree | nil
|
---@return neotest.Tree | nil
|
||||||
function PythonNeotestAdapter.discover_positions(path)
|
function PythonNeotestAdapter.discover_positions(path)
|
||||||
|
local root = PythonNeotestAdapter.root(path) or vim.loop.cwd()
|
||||||
|
local python = get_python(root)
|
||||||
|
local runner = get_runner(python)
|
||||||
|
|
||||||
|
-- Parse the file while pytest is running
|
||||||
local query = [[
|
local query = [[
|
||||||
;; Match undecorated functions
|
;; Match undecorated functions
|
||||||
((function_definition
|
((function_definition
|
||||||
@@ -109,12 +116,15 @@ function PythonNeotestAdapter.discover_positions(path)
|
|||||||
(#not-has-parent? @namespace.definition decorated_definition)
|
(#not-has-parent? @namespace.definition decorated_definition)
|
||||||
)
|
)
|
||||||
]]
|
]]
|
||||||
local root = PythonNeotestAdapter.root(path)
|
local positions = lib.treesitter.parse_positions(path, query, {
|
||||||
local python = get_python(root)
|
|
||||||
local runner = get_runner(python)
|
|
||||||
return lib.treesitter.parse_positions(path, query, {
|
|
||||||
require_namespaces = runner == "unittest",
|
require_namespaces = runner == "unittest",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if runner == "pytest" and pytest_discover_instances then
|
||||||
|
pytest.augment_positions(python, get_script(), path, positions, root)
|
||||||
|
end
|
||||||
|
|
||||||
|
return positions
|
||||||
end
|
end
|
||||||
|
|
||||||
---@async
|
---@async
|
||||||
@@ -232,6 +242,9 @@ setmetatable(PythonNeotestAdapter, {
|
|||||||
if type(opts.dap) == "table" then
|
if type(opts.dap) == "table" then
|
||||||
dap_args = opts.dap
|
dap_args = opts.dap
|
||||||
end
|
end
|
||||||
|
if opts.pytest_discover_instances ~= nil then
|
||||||
|
pytest_discover_instances = opts.pytest_discover_instances
|
||||||
|
end
|
||||||
return PythonNeotestAdapter
|
return PythonNeotestAdapter
|
||||||
end,
|
end,
|
||||||
})
|
})
|
||||||
|
101
lua/neotest-python/pytest.lua
Normal file
101
lua/neotest-python/pytest.lua
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
local lib = require("neotest.lib")
|
||||||
|
local logger = require("neotest.logging")
|
||||||
|
|
||||||
|
local M = {}
|
||||||
|
|
||||||
|
---@async
|
||||||
|
---Add test instances for path in root to positions
|
||||||
|
---@param positions neotest.Tree
|
||||||
|
---@param test_params table<string, string[]>
|
||||||
|
local function add_test_instances(positions, test_params)
|
||||||
|
for _, node in positions:iter_nodes() do
|
||||||
|
local position = node:data()
|
||||||
|
if position.type == "test" then
|
||||||
|
local pos_params = test_params[position.id] or {}
|
||||||
|
for _, params_str in ipairs(pos_params) do
|
||||||
|
local new_data = vim.tbl_extend("force", position, {
|
||||||
|
id = string.format("%s[%s]", position.id, params_str),
|
||||||
|
name = string.format("%s[%s]", position.name, params_str),
|
||||||
|
})
|
||||||
|
new_data.range = nil
|
||||||
|
|
||||||
|
local new_pos = node:new(new_data, {}, node._key, {}, {})
|
||||||
|
node:add_child(new_data.id, new_pos)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
---@async
|
||||||
|
---@param path string
|
||||||
|
---@return boolean
|
||||||
|
local function has_parametrize(path)
|
||||||
|
local query = [[
|
||||||
|
;; Detect parametrize decorators
|
||||||
|
(decorator
|
||||||
|
(call
|
||||||
|
function:
|
||||||
|
(attribute
|
||||||
|
attribute: (identifier) @parametrize
|
||||||
|
(#eq? @parametrize "parametrize"))))
|
||||||
|
]]
|
||||||
|
local content = lib.files.read(path)
|
||||||
|
local ts_root, lang = lib.treesitter.get_parse_root(path, content, { fast = true })
|
||||||
|
local built_query = lib.treesitter.normalise_query(lang, query)
|
||||||
|
return built_query:iter_matches(ts_root, content)() ~= nil
|
||||||
|
end
|
||||||
|
|
||||||
|
---@async
|
||||||
|
---Discover test instances for path (by running script using python)
|
||||||
|
---@param python string[]
|
||||||
|
---@param script string
|
||||||
|
---@param path string
|
||||||
|
---@param positions neotest.Tree
|
||||||
|
---@param root string
|
||||||
|
local function discover_params(python, script, path, positions, root)
|
||||||
|
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", path })
|
||||||
|
logger.debug("Running test instance discovery:", cmd)
|
||||||
|
|
||||||
|
local test_params = {}
|
||||||
|
local res, data = lib.process.run(cmd, { stdout = true, stderr = true })
|
||||||
|
if res ~= 0 then
|
||||||
|
logger.warn("Pytest discovery failed")
|
||||||
|
if data.stderr then
|
||||||
|
logger.debug(data.stderr)
|
||||||
|
end
|
||||||
|
return {}
|
||||||
|
end
|
||||||
|
|
||||||
|
for line in vim.gsplit(data.stdout, "\n", true) do
|
||||||
|
local param_index = string.find(line, "[", nil, true)
|
||||||
|
if param_index then
|
||||||
|
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
|
||||||
|
local param_id = string.sub(line, param_index + 1, #line - 1)
|
||||||
|
|
||||||
|
if positions:get_key(test_id) then
|
||||||
|
if not test_params[test_id] then
|
||||||
|
test_params[test_id] = { param_id }
|
||||||
|
else
|
||||||
|
table.insert(test_params[test_id], param_id)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return test_params
|
||||||
|
end
|
||||||
|
|
||||||
|
---@async
|
||||||
|
---Launch pytest to discover test instances for path, if configured
|
||||||
|
---@param python string[]
|
||||||
|
---@param script string
|
||||||
|
---@param path string
|
||||||
|
---@param positions neotest.Tree
|
||||||
|
---@param root string
|
||||||
|
function M.augment_positions(python, script, path, positions, root)
|
||||||
|
if has_parametrize(path) then
|
||||||
|
local test_params = discover_params(python, script, path, positions, root)
|
||||||
|
add_test_instances(positions, test_params)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
return M
|
@@ -41,6 +41,12 @@ parser.add_argument("args", nargs="*")
|
|||||||
|
|
||||||
|
|
||||||
def main(argv: List[str]):
|
def main(argv: List[str]):
|
||||||
|
if "--pytest-collect" in argv:
|
||||||
|
argv.remove("--pytest-collect")
|
||||||
|
from .pytest import collect
|
||||||
|
collect(argv)
|
||||||
|
return
|
||||||
|
|
||||||
args = parser.parse_args(argv)
|
args = parser.parse_args(argv)
|
||||||
adapter = get_adapter(TestRunner(args.runner))
|
adapter = get_adapter(TestRunner(args.runner))
|
||||||
|
|
||||||
|
@@ -106,6 +106,7 @@ class NeotestResultCollector:
|
|||||||
if getattr(item, "callspec", None) is not None:
|
if getattr(item, "callspec", None) is not None:
|
||||||
# Parametrized test
|
# Parametrized test
|
||||||
msg_prefix = f"[{item.callspec.id}] "
|
msg_prefix = f"[{item.callspec.id}] "
|
||||||
|
pos_id += f"[{item.callspec.id}]"
|
||||||
if report.outcome == "failed":
|
if report.outcome == "failed":
|
||||||
exc_repr = report.longrepr
|
exc_repr = report.longrepr
|
||||||
# Test fails due to condition outside of test e.g. xfail
|
# Test fails due to condition outside of test e.g. xfail
|
||||||
@@ -176,3 +177,7 @@ class NeotestDebugpyPlugin:
|
|||||||
py_db.stop_on_unhandled_exception(py_db, thread, additional_info, excinfo)
|
py_db.stop_on_unhandled_exception(py_db, thread, additional_info, excinfo)
|
||||||
finally:
|
finally:
|
||||||
additional_info.is_tracing -= 1
|
additional_info.is_tracing -= 1
|
||||||
|
|
||||||
|
|
||||||
|
def collect(args):
|
||||||
|
pytest.main(['--collect-only', '-q'] + args)
|
||||||
|
Reference in New Issue
Block a user