feat(pytest): use socket instead of a shitton of processes

This commit is contained in:
Itai Bohadana
2025-08-26 15:48:59 +03:00
parent ed9b4d794b
commit 9f6fbd6e04
8 changed files with 401 additions and 50 deletions

View File

@@ -65,9 +65,13 @@ return function(config)
local python_command = config.get_python_command(root)
local runner = config.get_runner(python_command)
local positions = lib.treesitter.parse_positions(path, base.treesitter_queries(runner, config, python_command), {
require_namespaces = runner == "unittest",
})
local positions = lib.treesitter.parse_positions(
path,
base.treesitter_queries(runner, config, python_command),
{
require_namespaces = runner == "unittest",
}
)
if runner == "pytest" and config.pytest_discovery then
pytest.augment_positions(python_command, base.get_script_path(), path, positions, root)

View File

@@ -96,7 +96,7 @@ end
function M.get_script_path()
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
for _, path in ipairs(paths) do
if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then
if vim.endswith(path, ("python%sneotest.py"):format(lib.files.sep)) then
return path
end
end
@@ -110,17 +110,6 @@ end
---@return string
local function scan_test_function_pattern(runner, config, python_command)
local test_function_pattern = "^test"
if runner == "pytest" and config.pytest_discovery then
local cmd = vim.tbl_flatten({ python_command, M.get_script_path(), "--pytest-extract-test-name-template" })
local _, data = lib.process.run(cmd, { stdout = true, stderr = true })
for line in vim.gsplit(data.stdout, "\n", true) do
if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then
local pytest_option = vim.json.decode(line)
test_function_pattern = pytest_option.python_functions
end
end
end
return test_function_pattern
end
@@ -130,7 +119,8 @@ end
---@return string
M.treesitter_queries = function(runner, config, python_command)
local test_function_pattern = scan_test_function_pattern(runner, config, python_command)
return string.format([[
return string.format(
[[
;; Match undecorated functions
((function_definition
name: (identifier) @test.name)
@@ -158,7 +148,10 @@ M.treesitter_queries = function(runner, config, python_command)
@namespace.definition
(#not-has-parent? @namespace.definition decorated_definition)
)
]], test_function_pattern, test_function_pattern)
]],
test_function_pattern,
test_function_pattern
)
end
M.get_root =
@@ -192,7 +185,7 @@ function M.get_runner(python_path)
then
return vim_test_runner
end
local runner = M.module_exists("pytest", python_path) and "pytest"
local runner = M.module_exists("pytest_", python_path) and "pytest"
or M.module_exists("django", python_path) and "django"
or "unittest"
stored_runners[command_str] = runner

View File

@@ -1,4 +1,5 @@
local lib = require("neotest.lib")
local nio = require("nio")
local logger = require("neotest.logging")
local M = {}
@@ -15,7 +16,7 @@ local function add_test_instances(positions, test_params)
for _, params_str in ipairs(pos_params) do
local new_data = vim.tbl_extend("force", position, {
id = string.format("%s[%s]", position.id, params_str),
name = string.format("%s[%s]", position.name, params_str),
name = string.format("%s", params_str),
})
new_data.range = nil
@@ -26,6 +27,73 @@ local function add_test_instances(positions, test_params)
end
end
local socket_path = ""
local socket_future = nio.control.future()
local socket_start = nio.control.future()
--- Run a command, connect to the UNIX socket it prints, send messages, return response
-- @param cmd string: Command to run (should output socket path)
-- @param messages table|string: One or more messages to send
-- @param callback uv.read_start.callback: One or more messages to send
-- @return string: Concatenated response from server
local function get_socket_path(cmd, messages, callback)
-- 1. Run the command and capture its output (socket path)
local stdout = assert(vim.uv.new_pipe())
local stderr = assert(vim.uv.new_pipe())
local stdin = assert(vim.uv.new_pipe())
if socket_path == "" and not socket_start.is_set() then
socket_start.set()
local handle
handle, _ = vim.uv.spawn(cmd[1], {
stdio = { stdin, stdout, stderr },
detached = false,
args = #cmd > 1 and vim.list_slice(cmd, 2, #cmd) or { cmd[1] },
}, function(code, signal)
vim.uv.close(stdout)
vim.uv.close(stderr)
vim.uv.close(stdin)
vim.uv.close(handle)
end)
vim.uv.read_start(stdout, function(err, s)
if s ~= nil then
socket_path = string.gsub(s, "\n$", "")
socket_future.set()
end
end)
vim.uv.read_start(stderr, function(err, s)
if err ~= nil then
vim.print(err, s)
end
end)
end
socket_future.wait()
-- 2. Connect to the unix socket
local client = assert(vim.uv.new_pipe(false))
client:connect(socket_path, function(err)
if err ~= nil then
vim.print("Error", err)
end
-- 3. Send message(s)
if type(messages) == "string" then
messages = { messages }
end
for _, msg in ipairs(messages) do
vim.uv.write(client, msg .. "\n", function(err_write)
client:read_start(function(err_read, data)
if err_read ~= nil then
vim.print(err_read, data)
else
callback(err_read, data)
end
end)
end)
end
end)
end
---@async
---@param path string
---@return boolean
@@ -57,31 +125,32 @@ local function discover_params(python, script, path, positions, root)
logger.debug("Running test instance discovery:", cmd)
local test_params = {}
local res, data = lib.process.run(cmd, { stdout = true, stderr = true })
if res ~= 0 then
logger.warn("Pytest discovery failed")
if data.stderr then
logger.debug(data.stderr)
get_socket_path(cmd, path, function(err, data)
if err ~= nil then
vim.print(err, data)
return
end
if data == nil then
return
end
return {}
end
for line in vim.gsplit(data.stdout, "\n", true) do
local param_index = string.find(line, "[", nil, true)
if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
local param_id = string.sub(line, param_index + 1, #line - 1)
for line in vim.gsplit(data, "\n", { trimempty = true, plain = true }) do
local param_index = string.find(line, "[", nil, true)
if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
local param_id = string.sub(line, param_index + 1, #line - 1)
if positions:get_key(test_id) then
if not test_params[test_id] then
test_params[test_id] = { param_id }
else
table.insert(test_params[test_id], param_id)
if positions:get_key(test_id) then
if not test_params[test_id] then
test_params[test_id] = { param_id }
else
table.insert(test_params[test_id], param_id)
end
end
end
end
end
return test_params
return add_test_instances(positions, test_params)
end)
end
---@async
@@ -93,8 +162,7 @@ end
---@param root string
function M.augment_positions(python, script, path, positions, root)
if has_parametrize(path) then
local test_params = discover_params(python, script, path, positions, root)
add_test_instances(positions, test_params)
discover_params(python, script, path, positions, root)
end
end