feat(pytest): populate parameterized test instances (#36)
Co-authored-by: Rónán Carrigan <rcarriga@tcd.ie>
This commit is contained in:
@@ -38,7 +38,9 @@ require("neotest").setup({
|
||||
is_test_file = function(file_path)
|
||||
...
|
||||
end,
|
||||
|
||||
-- !!EXPERIMENTAL!! Enable shelling out to `pytest` to discover test
|
||||
-- instances for files containing a parametrize mark (default: false)
|
||||
pytest_discover_instances = true,
|
||||
})
|
||||
}
|
||||
})
|
||||
|
@@ -1,6 +1,7 @@
|
||||
local async = require("neotest.async")
|
||||
local lib = require("neotest.lib")
|
||||
local base = require("neotest-python.base")
|
||||
local pytest = require("neotest-python.pytest")
|
||||
|
||||
local function get_script()
|
||||
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
|
||||
@@ -15,6 +16,7 @@ end
|
||||
|
||||
local dap_args
|
||||
local is_test_file = base.is_test_file
|
||||
local pytest_discover_instances = false
|
||||
|
||||
local function get_strategy_config(strategy, python, program, args)
|
||||
local config = {
|
||||
@@ -80,8 +82,13 @@ function PythonNeotestAdapter.filter_dir(name)
|
||||
end
|
||||
|
||||
---@async
|
||||
---@return Tree | nil
|
||||
---@return neotest.Tree | nil
|
||||
function PythonNeotestAdapter.discover_positions(path)
|
||||
local root = PythonNeotestAdapter.root(path) or vim.loop.cwd()
|
||||
local python = get_python(root)
|
||||
local runner = get_runner(python)
|
||||
|
||||
-- Parse the file while pytest is running
|
||||
local query = [[
|
||||
;; Match undecorated functions
|
||||
((function_definition
|
||||
@@ -109,12 +116,15 @@ function PythonNeotestAdapter.discover_positions(path)
|
||||
(#not-has-parent? @namespace.definition decorated_definition)
|
||||
)
|
||||
]]
|
||||
local root = PythonNeotestAdapter.root(path)
|
||||
local python = get_python(root)
|
||||
local runner = get_runner(python)
|
||||
return lib.treesitter.parse_positions(path, query, {
|
||||
local positions = lib.treesitter.parse_positions(path, query, {
|
||||
require_namespaces = runner == "unittest",
|
||||
})
|
||||
|
||||
if runner == "pytest" and pytest_discover_instances then
|
||||
pytest.augment_positions(python, get_script(), path, positions, root)
|
||||
end
|
||||
|
||||
return positions
|
||||
end
|
||||
|
||||
---@async
|
||||
@@ -232,6 +242,9 @@ setmetatable(PythonNeotestAdapter, {
|
||||
if type(opts.dap) == "table" then
|
||||
dap_args = opts.dap
|
||||
end
|
||||
if opts.pytest_discover_instances ~= nil then
|
||||
pytest_discover_instances = opts.pytest_discover_instances
|
||||
end
|
||||
return PythonNeotestAdapter
|
||||
end,
|
||||
})
|
||||
|
101
lua/neotest-python/pytest.lua
Normal file
101
lua/neotest-python/pytest.lua
Normal file
@@ -0,0 +1,101 @@
|
||||
local lib = require("neotest.lib")
|
||||
local logger = require("neotest.logging")
|
||||
|
||||
local M = {}
|
||||
|
||||
---@async
|
||||
---Add test instances for path in root to positions
|
||||
---@param positions neotest.Tree
|
||||
---@param test_params table<string, string[]>
|
||||
local function add_test_instances(positions, test_params)
|
||||
for _, node in positions:iter_nodes() do
|
||||
local position = node:data()
|
||||
if position.type == "test" then
|
||||
local pos_params = test_params[position.id] or {}
|
||||
for _, params_str in ipairs(pos_params) do
|
||||
local new_data = vim.tbl_extend("force", position, {
|
||||
id = string.format("%s[%s]", position.id, params_str),
|
||||
name = string.format("%s[%s]", position.name, params_str),
|
||||
})
|
||||
new_data.range = nil
|
||||
|
||||
local new_pos = node:new(new_data, {}, node._key, {}, {})
|
||||
node:add_child(new_data.id, new_pos)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
---@async
|
||||
---@param path string
|
||||
---@return boolean
|
||||
local function has_parametrize(path)
|
||||
local query = [[
|
||||
;; Detect parametrize decorators
|
||||
(decorator
|
||||
(call
|
||||
function:
|
||||
(attribute
|
||||
attribute: (identifier) @parametrize
|
||||
(#eq? @parametrize "parametrize"))))
|
||||
]]
|
||||
local content = lib.files.read(path)
|
||||
local ts_root, lang = lib.treesitter.get_parse_root(path, content, { fast = true })
|
||||
local built_query = lib.treesitter.normalise_query(lang, query)
|
||||
return built_query:iter_matches(ts_root, content)() ~= nil
|
||||
end
|
||||
|
||||
---@async
|
||||
---Discover test instances for path (by running script using python)
|
||||
---@param python string[]
|
||||
---@param script string
|
||||
---@param path string
|
||||
---@param positions neotest.Tree
|
||||
---@param root string
|
||||
local function discover_params(python, script, path, positions, root)
|
||||
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", path })
|
||||
logger.debug("Running test instance discovery:", cmd)
|
||||
|
||||
local test_params = {}
|
||||
local res, data = lib.process.run(cmd, { stdout = true, stderr = true })
|
||||
if res ~= 0 then
|
||||
logger.warn("Pytest discovery failed")
|
||||
if data.stderr then
|
||||
logger.debug(data.stderr)
|
||||
end
|
||||
return {}
|
||||
end
|
||||
|
||||
for line in vim.gsplit(data.stdout, "\n", true) do
|
||||
local param_index = string.find(line, "[", nil, true)
|
||||
if param_index then
|
||||
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
|
||||
local param_id = string.sub(line, param_index + 1, #line - 1)
|
||||
|
||||
if positions:get_key(test_id) then
|
||||
if not test_params[test_id] then
|
||||
test_params[test_id] = { param_id }
|
||||
else
|
||||
table.insert(test_params[test_id], param_id)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return test_params
|
||||
end
|
||||
|
||||
---@async
|
||||
---Launch pytest to discover test instances for path, if configured
|
||||
---@param python string[]
|
||||
---@param script string
|
||||
---@param path string
|
||||
---@param positions neotest.Tree
|
||||
---@param root string
|
||||
function M.augment_positions(python, script, path, positions, root)
|
||||
if has_parametrize(path) then
|
||||
local test_params = discover_params(python, script, path, positions, root)
|
||||
add_test_instances(positions, test_params)
|
||||
end
|
||||
end
|
||||
|
||||
return M
|
@@ -41,6 +41,12 @@ parser.add_argument("args", nargs="*")
|
||||
|
||||
|
||||
def main(argv: List[str]):
|
||||
if "--pytest-collect" in argv:
|
||||
argv.remove("--pytest-collect")
|
||||
from .pytest import collect
|
||||
collect(argv)
|
||||
return
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
adapter = get_adapter(TestRunner(args.runner))
|
||||
|
||||
|
@@ -106,6 +106,7 @@ class NeotestResultCollector:
|
||||
if getattr(item, "callspec", None) is not None:
|
||||
# Parametrized test
|
||||
msg_prefix = f"[{item.callspec.id}] "
|
||||
pos_id += f"[{item.callspec.id}]"
|
||||
if report.outcome == "failed":
|
||||
exc_repr = report.longrepr
|
||||
# Test fails due to condition outside of test e.g. xfail
|
||||
@@ -176,3 +177,7 @@ class NeotestDebugpyPlugin:
|
||||
py_db.stop_on_unhandled_exception(py_db, thread, additional_info, excinfo)
|
||||
finally:
|
||||
additional_info.is_tracing -= 1
|
||||
|
||||
|
||||
def collect(args):
|
||||
pytest.main(['--collect-only', '-q'] + args)
|
||||
|
Reference in New Issue
Block a user