diff --git a/lua/neotest-python/adapter.lua b/lua/neotest-python/adapter.lua new file mode 100644 index 0000000..7f1d20e --- /dev/null +++ b/lua/neotest-python/adapter.lua @@ -0,0 +1,139 @@ +local nio = require("nio") +local lib = require("neotest.lib") +local pytest = require("neotest-python.pytest") +local base = require("neotest-python.base") + +---@class neotest-python._AdapterConfig +---@field dap_args? table +---@field pytest_discovery? boolean +---@field is_test_file fun(file_path: string):boolean +---@field get_python_command fun(root: string):string[] +---@field get_args fun(runner: string, position: neotest.Position, strategy: string): string[] +---@field get_runner fun(python_command: string[]): string + +---@param config neotest-python._AdapterConfig +---@return neotest.Adapter +return function(config) + ---@param run_args neotest.RunArgs + ---@param results_path string + ---@param stream_path string + ---@param runner string + ---@return string[] + local function build_script_args(run_args, results_path, stream_path, runner) + local script_args = { + "--results-file", + results_path, + "--stream-file", + stream_path, + "--runner", + runner, + } + + if config.pytest_discovery then + table.insert(script_args, "--emit-parameterized-ids") + end + + local position = run_args.tree:data() + + table.insert(script_args, "--") + + vim.list_extend(script_args, config.get_args(runner, position, run_args.strategy)) + + if run_args.extra_args then + vim.list_extend(script_args, run_args.extra_args) + end + + if position then + table.insert(script_args, position.id) + end + + return script_args + end + + ---@type neotest.Adapter + return { + name = "neotest-python", + root = base.get_root, + filter_dir = function(name) + return name ~= "venv" + end, + is_test_file = config.is_test_file, + discover_positions = function(path) + local root = base.get_root(path) or vim.loop.cwd() or "" + + local python_command = config.get_python_command(root) + local runner = config.get_runner(python_command) + + local positions = lib.treesitter.parse_positions(path, base.treesitter_queries, { + require_namespaces = runner == "unittest", + }) + + if runner == "pytest" and config.pytest_discovery then + pytest.augment_positions(python_command, base.get_script_path(), path, positions, root) + end + + return positions + end, + ---@param args neotest.RunArgs + ---@return neotest.RunSpec + build_spec = function(args) + local position = args.tree:data() + + local root = base.get_root(position.path) or vim.loop.cwd() or "" + + local python_command = config.get_python_command(root) + local runner = config.get_runner(python_command) + + local results_path = nio.fn.tempname() + local stream_path = nio.fn.tempname() + lib.files.write(stream_path, "") + + local stream_data, stop_stream = lib.files.stream_lines(stream_path) + + local script_args = build_script_args(args, results_path, stream_path, runner) + local script_path = base.get_script_path() + + local strategy_config + if args.strategy == "dap" then + strategy_config = + base.create_dap_config(python_command, script_path, script_args, config.dap_args) + P(config, strategy_config) + end + ---@type neotest.RunSpec + return { + command = vim.iter({ python_command, script_path, script_args }):flatten():totable(), + context = { + results_path = results_path, + stop_stream = stop_stream, + }, + stream = function() + return function() + local lines = stream_data() + local results = {} + for _, line in ipairs(lines) do + local result = vim.json.decode(line, { luanil = { object = true } }) + results[result.id] = result.result + end + return results + end + end, + strategy = strategy_config, + } + end, + ---@param spec neotest.RunSpec + ---@param result neotest.StrategyResult + ---@return neotest.Result[] + results = function(spec, result) + spec.context.stop_stream() + local success, data = pcall(lib.files.read, spec.context.results_path) + if not success then + data = "{}" + end + local results = vim.json.decode(data, { luanil = { object = true } }) + for _, pos_result in pairs(results) do + result.output_path = pos_result.output_path + end + return results + end, + } +end diff --git a/lua/neotest-python/base.lua b/lua/neotest-python/base.lua index c529047..90cd6da 100644 --- a/lua/neotest-python/base.lua +++ b/lua/neotest-python/base.lua @@ -1,4 +1,4 @@ -local async = require("neotest.async") +local nio = require("nio") local lib = require("neotest.lib") local Path = require("plenary.path") @@ -14,17 +14,21 @@ function M.is_test_file(file_path) end M.module_exists = function(module, python_command) - return lib.process.run(vim.tbl_flatten({ - python_command, - "-c", - "import " .. module, - })) == 0 + return lib.process.run(vim + .iter({ + python_command, + "-c", + "import " .. module, + }) + :flatten() + :totable()) == 0 end local python_command_mem = {} ---@return string[] function M.get_python_command(root) + root = root or vim.loop.cwd() if python_command_mem[root] then return python_command_mem[root] end @@ -35,7 +39,7 @@ function M.get_python_command(root) end for _, pattern in ipairs({ "*", ".*" }) do - local match = async.fn.glob(Path:new(root or async.fn.getcwd(), pattern, "pyvenv.cfg").filename) + local match = nio.fn.glob(Path:new(root or nio.fn.getcwd(), pattern, "pyvenv.cfg").filename) if match ~= "" then python_command_mem[root] = { (Path:new(match):parent() / "bin" / "python").filename } return python_command_mem[root] @@ -70,9 +74,89 @@ function M.get_python_command(root) -- Fallback to system Python. python_command_mem[root] = { - async.fn.exepath("python3") or async.fn.exepath("python") or "python", + nio.fn.exepath("python3") or nio.fn.exepath("python") or "python", } return python_command_mem[root] end +M.treesitter_queries = [[ + ;; Match undecorated functions + ((function_definition + name: (identifier) @test.name) + (#match? @test.name "^test")) + @test.definition + + ;; Match decorated function, including decorators in definition + (decorated_definition + ((function_definition + name: (identifier) @test.name) + (#match? @test.name "^test"))) + @test.definition + + ;; Match decorated classes, including decorators in definition + (decorated_definition + (class_definition + name: (identifier) @namespace.name)) + @namespace.definition + + ;; Match undecorated classes: namespaces nest so #not-has-parent is used + ;; to ensure each namespace is annotated only once + ( + (class_definition + name: (identifier) @namespace.name) + @namespace.definition + (#not-has-parent? @namespace.definition decorated_definition) + ) + ]] + +M.get_root = + lib.files.match_root_pattern("pyproject.toml", "setup.cfg", "mypy.ini", "pytest.ini", "setup.py") + +---@return string +function M.get_script_path() + local paths = vim.api.nvim_get_runtime_file("neotest.py", true) + for _, path in ipairs(paths) do + if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then + return path + end + end + + error("neotest.py not found") +end + +function M.create_dap_config(python_path, script_path, script_args, dap_args) + return vim.tbl_extend("keep", { + type = "python", + name = "Neotest Debugger", + request = "launch", + python = python_path, + program = script_path, + cwd = nio.fn.getcwd(), + args = script_args, + }, dap_args or {}) +end + +local stored_runners = {} + +function M.get_runner(python_path) + local command_str = table.concat(python_path, " ") + if stored_runners[command_str] then + return stored_runners[command_str] + end + local vim_test_runner = vim.g["test#python#runner"] + if vim_test_runner == "pyunit" then + return "unittest" + end + if + vim_test_runner and lib.func_util.index({ "unittest", "pytest", "django" }, vim_test_runner) + then + return vim_test_runner + end + local runner = M.module_exists("pytest", python_path) and "pytest" + or M.module_exists("django", python_path) and "django" + or "unittest" + stored_runners[command_str] = runner + return runner +end + return M diff --git a/lua/neotest-python/init.lua b/lua/neotest-python/init.lua index 8d9a750..73cfc6d 100644 --- a/lua/neotest-python/init.lua +++ b/lua/neotest-python/init.lua @@ -1,263 +1,77 @@ -local async = require("neotest.async") -local lib = require("neotest.lib") local base = require("neotest-python.base") -local pytest = require("neotest-python.pytest") +local create_adapter = require("neotest-python.adapter") -local function get_script() - local paths = vim.api.nvim_get_runtime_file("neotest.py", true) - for _, path in ipairs(paths) do - if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then - return path - end - end - - error("neotest.py not found") -end - -local dap_args -local is_test_file = base.is_test_file -local pytest_discover_instances = false - -local function get_strategy_config(strategy, python, program, args) - local config = { - dap = function() - return vim.tbl_extend("keep", { - type = "python", - name = "Neotest Debugger", - request = "launch", - python = python, - program = program, - cwd = async.fn.getcwd(), - args = args, - }, dap_args or {}) - end, - } - if config[strategy] then - return config[strategy]() - end -end - -local get_python = function(root) - if not root then - root = vim.loop.cwd() - end - return base.get_python_command(root) -end - -local get_args = function() - return {} -end - -local stored_runners = {} - -local get_runner = function(python_command) - local command_str = table.concat(python_command, " ") - if stored_runners[command_str] then - return stored_runners[command_str] - end - local vim_test_runner = vim.g["test#python#runner"] - if vim_test_runner == "pyunit" then - return "unittest" - end - if - vim_test_runner and lib.func_util.index({ "unittest", "pytest", "django" }, vim_test_runner) - then - return vim_test_runner - end - local runner = base.module_exists("pytest", python_command) and "pytest" - or base.module_exists("django", python_command) and "django" - or "unittest" - stored_runners[command_str] = runner - return runner -end - ----@type neotest.Adapter -local PythonNeotestAdapter = { name = "neotest-python" } - -PythonNeotestAdapter.root = - lib.files.match_root_pattern("pyproject.toml", "setup.cfg", "mypy.ini", "pytest.ini", "setup.py") - -function PythonNeotestAdapter.is_test_file(file_path) - return is_test_file(file_path) -end - -function PythonNeotestAdapter.filter_dir(name) - return name ~= "venv" -end - ----@async ----@return neotest.Tree | nil -function PythonNeotestAdapter.discover_positions(path) - local root = PythonNeotestAdapter.root(path) or vim.loop.cwd() - local python = get_python(root) - local runner = get_runner(python) - - -- Parse the file while pytest is running - local query = [[ - ;; Match undecorated functions - ((function_definition - name: (identifier) @test.name) - (#match? @test.name "^test")) - @test.definition - ;; Match decorated function, including decorators in definition - (decorated_definition - ((function_definition - name: (identifier) @test.name) - (#match? @test.name "^test"))) - @test.definition - - ;; Match decorated classes, including decorators in definition - (decorated_definition - (class_definition - name: (identifier) @namespace.name)) - @namespace.definition - ;; Match undecorated classes: namespaces nest so #not-has-parent is used - ;; to ensure each namespace is annotated only once - ( - (class_definition - name: (identifier) @namespace.name) - @namespace.definition - (#not-has-parent? @namespace.definition decorated_definition) - ) - ]] - local positions = lib.treesitter.parse_positions(path, query, { - require_namespaces = runner == "unittest", - }) - - if runner == "pytest" and pytest_discover_instances then - pytest.augment_positions(python, get_script(), path, positions, root) - end - - return positions -end - ----@async ----@param args neotest.RunArgs ----@return neotest.RunSpec -function PythonNeotestAdapter.build_spec(args) - local position = args.tree:data() - local results_path = async.fn.tempname() - local stream_path = async.fn.tempname() - lib.files.write(stream_path, "") - - local root = PythonNeotestAdapter.root(position.path) - local python = get_python(root) - local runner = get_runner(python) - local stream_data, stop_stream = lib.files.stream_lines(stream_path) - local script_args = vim.tbl_flatten({ - "--results-file", - results_path, - "--stream-file", - stream_path, - "--runner", - runner, - }) - if pytest_discover_instances then - table.insert(script_args, "--emit-parameterized-ids") - end - - table.insert(script_args, "--") - vim.list_extend(script_args, get_args(runner, position, args.strategy)) - if args.extra_args then - vim.list_extend(script_args, args.extra_args) - end - - if position then - table.insert(script_args, position.id) - end - local python_script = get_script() - local command = vim.tbl_flatten({ - python, - python_script, - script_args, - }) - local strategy_config = get_strategy_config(args.strategy, python, python_script, script_args) - ---@type neotest.RunSpec - return { - command = command, - context = { - results_path = results_path, - stop_stream = stop_stream, - }, - stream = function() - return function() - local lines = stream_data() - local results = {} - for _, line in ipairs(lines) do - local result = vim.json.decode(line, { luanil = { object = true } }) - results[result.id] = result.result - end - return results - end - end, - strategy = strategy_config, - } -end - ----@async ----@param spec neotest.RunSpec ----@param result neotest.StrategyResult ----@return neotest.Result[] -function PythonNeotestAdapter.results(spec, result) - spec.context.stop_stream() - local success, data = pcall(lib.files.read, spec.context.results_path) - if not success then - data = "{}" - end - -- TODO: Find out if this JSON option is supported in future - local results = vim.json.decode(data, { luanil = { object = true } }) - for _, pos_result in pairs(results) do - result.output_path = pos_result.output_path - end - return results -end +---@class neotest-python.AdapterConfig +---@field dap? table +---@field pytest_discover_instances? boolean +---@field is_test_file? fun(file_path: string):boolean +---@field python? string|string[]|fun(root: string):string[] +---@field args? string[]|fun(runner: string, position: neotest.Position, strategy: string): string[] +---@field runner? string|fun(python_command: string[]): string local is_callable = function(obj) return type(obj) == "function" or (type(obj) == "table" and obj.__call) end +---@param config neotest-python.AdapterConfig +local augment_config = function(config) + local get_python_command = base.get_python_command + if config.python then + get_python_command = function(root) + local python = config.python + + if is_callable(config.python) then + python = config.python(root) + end + + if type(python) == "string" then + return { python } + end + if type(python) == "table" then + return python + end + + return base.get_python(root) + end + end + + local get_args = function() + return {} + end + + if is_callable(config.args) then + get_args = config.args + elseif config.args then + get_args = function() + return config.args + end + end + + local get_runner = base.get_runner + if is_callable(config.runner) then + get_runner = config.runner + elseif config.runner then + get_runner = function() + return config.runner + end + end + + ---@type neotest-python._AdapterConfig + return { + pytest_discovery = config.pytest_discover_instances, + dap_args = config.dap, + get_runner = get_runner, + get_args = get_args, + is_test_file = config.is_test_file or base.is_test_file, + get_python_command = get_python_command, + } +end + +local PythonNeotestAdapter = create_adapter(augment_config({})) + setmetatable(PythonNeotestAdapter, { - __call = function(_, opts) - is_test_file = opts.is_test_file or is_test_file - if opts.python then - get_python = function(root) - local python = opts.python - - if is_callable(opts.python) then - python = opts.python(root) - end - - if type(python) == "string" then - return { python } - end - if type(python) == "table" then - return python - end - - return base.get_python(root) - end - end - if is_callable(opts.args) then - get_args = opts.args - elseif opts.args then - get_args = function() - return opts.args - end - end - if is_callable(opts.runner) then - get_runner = opts.runner - elseif opts.runner then - get_runner = function() - return opts.runner - end - end - if type(opts.dap) == "table" then - dap_args = opts.dap - end - if opts.pytest_discover_instances ~= nil then - pytest_discover_instances = opts.pytest_discover_instances - end - return PythonNeotestAdapter + __call = function(_, config) + return create_adapter(augment_config(config)) end, })