diff --git a/lua/neotest-python/adapter.lua b/lua/neotest-python/adapter.lua index 1d9e143..d0b5999 100644 --- a/lua/neotest-python/adapter.lua +++ b/lua/neotest-python/adapter.lua @@ -52,6 +52,7 @@ return function(config) ---@type neotest.Adapter return { + name = "neotest-python", root = base.get_root, filter_dir = function(name) @@ -64,7 +65,7 @@ return function(config) local python_command = config.get_python_command(root) local runner = config.get_runner(python_command) - local positions = lib.treesitter.parse_positions(path, base.treesitter_queries, { + local positions = lib.treesitter.parse_positions(path, base.treesitter_queries(runner, config, python_command), { require_namespaces = runner == "unittest", }) diff --git a/lua/neotest-python/base.lua b/lua/neotest-python/base.lua index e9c04d9..6d4a3d3 100644 --- a/lua/neotest-python/base.lua +++ b/lua/neotest-python/base.lua @@ -92,18 +92,56 @@ function M.get_python_command(root) return python_command_mem[root] end -M.treesitter_queries = [[ +---@return string +function M.get_script_path() + local paths = vim.api.nvim_get_runtime_file("neotest.py", true) + for _, path in ipairs(paths) do + if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then + return path + end + end + + error("neotest.py not found") +end + +---@param python_command string[] +---@param config neotest-python._AdapterConfig +---@param runner string +---@return string +local function scan_test_function_pattern(runner, config, python_command) + local test_function_pattern = "^test" + if runner == "pytest" and config.pytest_discovery then + local cmd = vim.tbl_flatten({ python_command, M.get_script_path(), "--pytest-extract-test-name-template" }) + local _, data = lib.process.run(cmd, { stdout = true, stderr = true }) + + for line in vim.gsplit(data.stdout, "\n", true) do + if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then + local pytest_option = vim.json.decode(line) + test_function_pattern = pytest_option.python_functions + end + end + end + return test_function_pattern +end + +---@param python_command string[] +---@param config neotest-python._AdapterConfig +---@param runner string +---@return string +M.treesitter_queries = function(runner, config, python_command) + local test_function_pattern = scan_test_function_pattern(runner, config, python_command) + return string.format([[ ;; Match undecorated functions ((function_definition name: (identifier) @test.name) - (#match? @test.name "^test")) + (#match? @test.name "%s")) @test.definition ;; Match decorated function, including decorators in definition (decorated_definition ((function_definition name: (identifier) @test.name) - (#match? @test.name "^test"))) + (#match? @test.name "%s"))) @test.definition ;; Match decorated classes, including decorators in definition @@ -120,23 +158,12 @@ M.treesitter_queries = [[ @namespace.definition (#not-has-parent? @namespace.definition decorated_definition) ) - ]] + ]], test_function_pattern, test_function_pattern) +end M.get_root = lib.files.match_root_pattern("pyproject.toml", "setup.cfg", "mypy.ini", "pytest.ini", "setup.py") ----@return string -function M.get_script_path() - local paths = vim.api.nvim_get_runtime_file("neotest.py", true) - for _, path in ipairs(paths) do - if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then - return path - end - end - - error("neotest.py not found") -end - function M.create_dap_config(python_path, script_path, script_args, dap_args) return vim.tbl_extend("keep", { type = "python", diff --git a/neotest_python/__init__.py b/neotest_python/__init__.py index 9f846e7..f137834 100644 --- a/neotest_python/__init__.py +++ b/neotest_python/__init__.py @@ -58,6 +58,13 @@ def main(argv: List[str]): collect(argv) return + if "--pytest-extract-test-name-template" in argv: + argv.remove("--pytest-extract-test-name-template") + from .pytest import extract_test_name_template + + extract_test_name_template(argv) + return + args = parser.parse_args(argv) adapter = get_adapter(TestRunner(args.runner), args.emit_parameterized_ids) diff --git a/neotest_python/pytest.py b/neotest_python/pytest.py index a2371f3..e12beb2 100644 --- a/neotest_python/pytest.py +++ b/neotest_python/pytest.py @@ -1,4 +1,5 @@ from io import StringIO +import json from pathlib import Path from typing import Callable, Dict, List, Optional, Union @@ -205,5 +206,16 @@ class NeotestDebugpyPlugin: additional_info.is_tracing -= 1 +class TestNameTemplateExtractor: + @staticmethod + def pytest_collection_modifyitems(config): + config = {"python_functions": config.getini("python_functions")[0]} + print(f"\n{json.dumps(config)}\n") + + +def extract_test_name_template(args): + pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor]) + + def collect(args): pytest.main(["--collect-only", "--verbosity=0", "-q"] + args)