feat(pytest): use socket instead of a shitton of processes
This commit is contained in:
@@ -65,9 +65,13 @@ return function(config)
|
||||
local python_command = config.get_python_command(root)
|
||||
local runner = config.get_runner(python_command)
|
||||
|
||||
local positions = lib.treesitter.parse_positions(path, base.treesitter_queries(runner, config, python_command), {
|
||||
require_namespaces = runner == "unittest",
|
||||
})
|
||||
local positions = lib.treesitter.parse_positions(
|
||||
path,
|
||||
base.treesitter_queries(runner, config, python_command),
|
||||
{
|
||||
require_namespaces = runner == "unittest",
|
||||
}
|
||||
)
|
||||
|
||||
if runner == "pytest" and config.pytest_discovery then
|
||||
pytest.augment_positions(python_command, base.get_script_path(), path, positions, root)
|
||||
|
||||
@@ -96,7 +96,7 @@ end
|
||||
function M.get_script_path()
|
||||
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
|
||||
for _, path in ipairs(paths) do
|
||||
if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then
|
||||
if vim.endswith(path, ("python%sneotest.py"):format(lib.files.sep)) then
|
||||
return path
|
||||
end
|
||||
end
|
||||
@@ -110,17 +110,6 @@ end
|
||||
---@return string
|
||||
local function scan_test_function_pattern(runner, config, python_command)
|
||||
local test_function_pattern = "^test"
|
||||
if runner == "pytest" and config.pytest_discovery then
|
||||
local cmd = vim.tbl_flatten({ python_command, M.get_script_path(), "--pytest-extract-test-name-template" })
|
||||
local _, data = lib.process.run(cmd, { stdout = true, stderr = true })
|
||||
|
||||
for line in vim.gsplit(data.stdout, "\n", true) do
|
||||
if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then
|
||||
local pytest_option = vim.json.decode(line)
|
||||
test_function_pattern = pytest_option.python_functions
|
||||
end
|
||||
end
|
||||
end
|
||||
return test_function_pattern
|
||||
end
|
||||
|
||||
@@ -130,7 +119,8 @@ end
|
||||
---@return string
|
||||
M.treesitter_queries = function(runner, config, python_command)
|
||||
local test_function_pattern = scan_test_function_pattern(runner, config, python_command)
|
||||
return string.format([[
|
||||
return string.format(
|
||||
[[
|
||||
;; Match undecorated functions
|
||||
((function_definition
|
||||
name: (identifier) @test.name)
|
||||
@@ -158,7 +148,10 @@ M.treesitter_queries = function(runner, config, python_command)
|
||||
@namespace.definition
|
||||
(#not-has-parent? @namespace.definition decorated_definition)
|
||||
)
|
||||
]], test_function_pattern, test_function_pattern)
|
||||
]],
|
||||
test_function_pattern,
|
||||
test_function_pattern
|
||||
)
|
||||
end
|
||||
|
||||
M.get_root =
|
||||
@@ -192,7 +185,7 @@ function M.get_runner(python_path)
|
||||
then
|
||||
return vim_test_runner
|
||||
end
|
||||
local runner = M.module_exists("pytest", python_path) and "pytest"
|
||||
local runner = M.module_exists("pytest_", python_path) and "pytest"
|
||||
or M.module_exists("django", python_path) and "django"
|
||||
or "unittest"
|
||||
stored_runners[command_str] = runner
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
local lib = require("neotest.lib")
|
||||
local nio = require("nio")
|
||||
local logger = require("neotest.logging")
|
||||
|
||||
local M = {}
|
||||
@@ -15,7 +16,7 @@ local function add_test_instances(positions, test_params)
|
||||
for _, params_str in ipairs(pos_params) do
|
||||
local new_data = vim.tbl_extend("force", position, {
|
||||
id = string.format("%s[%s]", position.id, params_str),
|
||||
name = string.format("%s[%s]", position.name, params_str),
|
||||
name = string.format("%s", params_str),
|
||||
})
|
||||
new_data.range = nil
|
||||
|
||||
@@ -26,6 +27,73 @@ local function add_test_instances(positions, test_params)
|
||||
end
|
||||
end
|
||||
|
||||
local socket_path = ""
|
||||
local socket_future = nio.control.future()
|
||||
local socket_start = nio.control.future()
|
||||
--- Run a command, connect to the UNIX socket it prints, send messages, return response
|
||||
-- @param cmd string: Command to run (should output socket path)
|
||||
-- @param messages table|string: One or more messages to send
|
||||
-- @param callback uv.read_start.callback: One or more messages to send
|
||||
-- @return string: Concatenated response from server
|
||||
local function get_socket_path(cmd, messages, callback)
|
||||
-- 1. Run the command and capture its output (socket path)
|
||||
local stdout = assert(vim.uv.new_pipe())
|
||||
local stderr = assert(vim.uv.new_pipe())
|
||||
local stdin = assert(vim.uv.new_pipe())
|
||||
if socket_path == "" and not socket_start.is_set() then
|
||||
socket_start.set()
|
||||
local handle
|
||||
handle, _ = vim.uv.spawn(cmd[1], {
|
||||
stdio = { stdin, stdout, stderr },
|
||||
detached = false,
|
||||
args = #cmd > 1 and vim.list_slice(cmd, 2, #cmd) or { cmd[1] },
|
||||
}, function(code, signal)
|
||||
vim.uv.close(stdout)
|
||||
vim.uv.close(stderr)
|
||||
vim.uv.close(stdin)
|
||||
vim.uv.close(handle)
|
||||
end)
|
||||
|
||||
vim.uv.read_start(stdout, function(err, s)
|
||||
if s ~= nil then
|
||||
socket_path = string.gsub(s, "\n$", "")
|
||||
socket_future.set()
|
||||
end
|
||||
end)
|
||||
|
||||
vim.uv.read_start(stderr, function(err, s)
|
||||
if err ~= nil then
|
||||
vim.print(err, s)
|
||||
end
|
||||
end)
|
||||
end
|
||||
socket_future.wait()
|
||||
|
||||
-- 2. Connect to the unix socket
|
||||
local client = assert(vim.uv.new_pipe(false))
|
||||
|
||||
client:connect(socket_path, function(err)
|
||||
if err ~= nil then
|
||||
vim.print("Error", err)
|
||||
end
|
||||
-- 3. Send message(s)
|
||||
if type(messages) == "string" then
|
||||
messages = { messages }
|
||||
end
|
||||
for _, msg in ipairs(messages) do
|
||||
vim.uv.write(client, msg .. "\n", function(err_write)
|
||||
client:read_start(function(err_read, data)
|
||||
if err_read ~= nil then
|
||||
vim.print(err_read, data)
|
||||
else
|
||||
callback(err_read, data)
|
||||
end
|
||||
end)
|
||||
end)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
---@async
|
||||
---@param path string
|
||||
---@return boolean
|
||||
@@ -57,31 +125,32 @@ local function discover_params(python, script, path, positions, root)
|
||||
logger.debug("Running test instance discovery:", cmd)
|
||||
|
||||
local test_params = {}
|
||||
local res, data = lib.process.run(cmd, { stdout = true, stderr = true })
|
||||
if res ~= 0 then
|
||||
logger.warn("Pytest discovery failed")
|
||||
if data.stderr then
|
||||
logger.debug(data.stderr)
|
||||
get_socket_path(cmd, path, function(err, data)
|
||||
if err ~= nil then
|
||||
vim.print(err, data)
|
||||
return
|
||||
end
|
||||
if data == nil then
|
||||
return
|
||||
end
|
||||
return {}
|
||||
end
|
||||
|
||||
for line in vim.gsplit(data.stdout, "\n", true) do
|
||||
local param_index = string.find(line, "[", nil, true)
|
||||
if param_index then
|
||||
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
|
||||
local param_id = string.sub(line, param_index + 1, #line - 1)
|
||||
for line in vim.gsplit(data, "\n", { trimempty = true, plain = true }) do
|
||||
local param_index = string.find(line, "[", nil, true)
|
||||
if param_index then
|
||||
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
|
||||
local param_id = string.sub(line, param_index + 1, #line - 1)
|
||||
|
||||
if positions:get_key(test_id) then
|
||||
if not test_params[test_id] then
|
||||
test_params[test_id] = { param_id }
|
||||
else
|
||||
table.insert(test_params[test_id], param_id)
|
||||
if positions:get_key(test_id) then
|
||||
if not test_params[test_id] then
|
||||
test_params[test_id] = { param_id }
|
||||
else
|
||||
table.insert(test_params[test_id], param_id)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
return test_params
|
||||
return add_test_instances(positions, test_params)
|
||||
end)
|
||||
end
|
||||
|
||||
---@async
|
||||
@@ -93,8 +162,7 @@ end
|
||||
---@param root string
|
||||
function M.augment_positions(python, script, path, positions, root)
|
||||
if has_parametrize(path) then
|
||||
local test_params = discover_params(python, script, path, positions, root)
|
||||
add_test_instances(positions, test_params)
|
||||
discover_params(python, script, path, positions, root)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
Reference in New Issue
Block a user