Compare commits

..

1 Commits

Author SHA1 Message Date
Itai Bohadana
ff75e11bca feat(pytest): use socket instead of a shitton of processes 2025-08-26 15:48:59 +03:00
6 changed files with 118 additions and 237 deletions

View File

@@ -65,13 +65,9 @@ return function(config)
local python_command = config.get_python_command(root) local python_command = config.get_python_command(root)
local runner = config.get_runner(python_command) local runner = config.get_runner(python_command)
local positions = lib.treesitter.parse_positions( local positions = lib.treesitter.parse_positions(path, base.treesitter_queries(runner, config, python_command), {
path, require_namespaces = runner == "unittest",
base.treesitter_queries(runner, config, python_command), })
{
require_namespaces = runner == "unittest",
}
)
if runner == "pytest" and config.pytest_discovery then if runner == "pytest" and config.pytest_discovery then
pytest.augment_positions(python_command, base.get_script_path(), path, positions, root) pytest.augment_positions(python_command, base.get_script_path(), path, positions, root)

View File

@@ -96,7 +96,7 @@ end
function M.get_script_path() function M.get_script_path()
local paths = vim.api.nvim_get_runtime_file("neotest.py", true) local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
for _, path in ipairs(paths) do for _, path in ipairs(paths) do
if vim.endswith(path, ("python%sneotest.py"):format(lib.files.sep)) then if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then
return path return path
end end
end end

View File

@@ -1,9 +1,54 @@
local lib = require("neotest.lib") local lib = require("neotest.lib")
local nio = require("nio")
local logger = require("neotest.logging") local logger = require("neotest.logging")
local M = {} local M = {}
local unix = require("socket.unix")
local socket_path = ""
--- Run a command, connect to the UNIX socket it prints, send messages, return response
-- @param cmd string: Command to run (should output socket path)
-- @param messages table|string: One or more messages to send
-- @return string: Concatenated response from server
local function talk_unix(cmd, messages)
-- 1. Run the command and capture its output (socket path)
if not socket_path then
local handle = assert(io.popen(cmd, "r"))
socket_path = handle:read("*l")
handle:close()
assert(socket_path, "Command did not return a socket path")
end
-- 2. Connect to the unix socket
local client = assert(unix())
assert(client:connect(socket_path))
-- 3. Send message(s)
if type(messages) == "string" then
messages = { messages }
end
for _, msg in ipairs(messages) do
assert(client:send(msg .. "\n"))
end
-- 4. Read response until EOF or timeout
client:settimeout(1)
local chunks = {}
while true do
local data, err = client:receive("*l")
if not data then
if err ~= "timeout" and err ~= "closed" then
error("Socket receive error: " .. tostring(err))
end
break
end
table.insert(chunks, data)
end
client:close()
return table.concat(chunks, "\n")
end
---@async ---@async
---Add test instances for path in root to positions ---Add test instances for path in root to positions
---@param positions neotest.Tree ---@param positions neotest.Tree
@@ -16,7 +61,7 @@ local function add_test_instances(positions, test_params)
for _, params_str in ipairs(pos_params) do for _, params_str in ipairs(pos_params) do
local new_data = vim.tbl_extend("force", position, { local new_data = vim.tbl_extend("force", position, {
id = string.format("%s[%s]", position.id, params_str), id = string.format("%s[%s]", position.id, params_str),
name = string.format("%s", params_str), name = string.format("%s[%s]", position.name, params_str),
}) })
new_data.range = nil new_data.range = nil
@@ -27,74 +72,6 @@ local function add_test_instances(positions, test_params)
end end
end end
local socket_path = ""
local socket_future = nio.control.future()
local socket_start = nio.control.future()
--- Run a command, connect to the UNIX socket it prints, send messages, return response
-- @param cmd string: Command to run (should output socket path)
-- @param messages table|string: One or more messages to send
-- @param callback uv.read_start.callback: One or more messages to send
-- @return string: Concatenated response from server
local function get_socket_path(cmd, messages, callback)
-- 1. Run the command and capture its output (socket path)
if socket_path == "" and not socket_start.is_set() then
local stdout = assert(vim.uv.new_pipe(true))
local stderr = assert(vim.uv.new_pipe(true))
local stdin = assert(vim.uv.new_pipe(true))
socket_start.set()
local handle
handle, _ = vim.uv.spawn(cmd[1], {
stdio = { stdin, stdout, stderr },
detached = true,
args = #cmd > 1 and vim.list_slice(cmd, 2, #cmd) or { cmd[1] },
}, function(code, signal)
vim.uv.close(stdout)
vim.uv.close(stderr)
vim.uv.close(stdin)
vim.uv.close(handle)
end)
vim.uv.read_start(stdout, function(err, s)
if s ~= nil then
socket_path = string.gsub(s, "\n$", "")
socket_future.set()
end
end)
vim.uv.read_start(stderr, function(err, s)
if err ~= nil then
vim.print(err, s)
end
end)
end
socket_future.wait()
-- 2. Connect to the unix socket
local client = assert(vim.uv.new_pipe(false))
client:connect(socket_path, function(err)
if err ~= nil then
vim.print("Error", err)
end
-- 3. Send message(s)
if type(messages) == "string" then
messages = { messages }
end
for _, msg in ipairs(messages) do
vim.uv.write(client, msg .. "\n", function(err_write)
client:read_start(function(err_read, data)
if err_read ~= nil then
vim.print(err_read, data)
else
callback(err_read, data)
end
client:close()
end)
end)
end
end)
end
---@async ---@async
---@param path string ---@param path string
---@return boolean ---@return boolean
@@ -122,36 +99,28 @@ end
---@param positions neotest.Tree ---@param positions neotest.Tree
---@param root string ---@param root string
local function discover_params(python, script, path, positions, root) local function discover_params(python, script, path, positions, root)
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", path }) local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", "-s", path })
logger.debug("Running test instance discovery:", cmd) logger.debug("Running test instance discovery:", cmd)
local test_params = {} local test_params = {}
get_socket_path(cmd, path, function(err, data) local data = talk_unix(cmd, path)
if err ~= nil then
vim.print(err, data)
return
end
if data == nil then
return
end
for line in vim.gsplit(data, "\n", { trimempty = true, plain = true }) do for line in vim.gsplit(data, "\n", true) do
local param_index = string.find(line, "[", nil, true) local param_index = string.find(line, "[", nil, true)
if param_index then if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1) local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
local param_id = string.sub(line, param_index + 1, #line - 1) local param_id = string.sub(line, param_index + 1, #line - 1)
if positions:get_key(test_id) then if positions:get_key(test_id) then
if not test_params[test_id] then if not test_params[test_id] then
test_params[test_id] = { param_id } test_params[test_id] = { param_id }
else else
table.insert(test_params[test_id], param_id) table.insert(test_params[test_id], param_id)
end
end end
end end
end end
return add_test_instances(positions, test_params) end
end) return test_params
end end
---@async ---@async
@@ -163,7 +132,8 @@ end
---@param root string ---@param root string
function M.augment_positions(python, script, path, positions, root) function M.augment_positions(python, script, path, positions, root)
if has_parametrize(path) then if has_parametrize(path) then
discover_params(python, script, path, positions, root) local test_params = discover_params(python, script, path, positions, root)
add_test_instances(positions, test_params)
end end
end end

View File

@@ -14,7 +14,7 @@ class TestRunner(str, Enum):
def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter: def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter:
if runner == TestRunner.PYTEST: if runner == TestRunner.PYTEST:
from .pytest_ import PytestNeotestAdapter from .pytest import PytestNeotestAdapter
return PytestNeotestAdapter(emit_parameterized_ids) return PytestNeotestAdapter(emit_parameterized_ids)
elif runner == TestRunner.UNITTEST: elif runner == TestRunner.UNITTEST:
@@ -53,14 +53,14 @@ parser.add_argument("args", nargs="*")
def main(argv: List[str]): def main(argv: List[str]):
if "--pytest-collect" in argv: if "--pytest-collect" in argv:
argv.remove("--pytest-collect") argv.remove("--pytest-collect")
from .pytest_ import collect from .pytest import collect
collect(argv) collect(argv)
return return
if "--pytest-extract-test-name-template" in argv: if "--pytest-extract-test-name-template" in argv:
argv.remove("--pytest-extract-test-name-template") argv.remove("--pytest-extract-test-name-template")
from .pytest_ import extract_test_name_template from .pytest import extract_test_name_template
extract_test_name_template(argv) extract_test_name_template(argv)
return return

View File

@@ -1,69 +1,30 @@
import asyncio
import atexit import atexit
from collections import deque
from collections.abc import Iterable from collections.abc import Iterable
import hashlib import hashlib
import itertools
import logging
import os import os
import select
import signal import signal
import socket import socket
import sys import sys
from _pytest.mark.structures import ParameterSet
from _pytest.python import IdMaker
import pytest import pytest
import importlib import importlib
import importlib.util import importlib.util
import argparse import argparse
import inspect import inspect
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Generator, Self, cast from typing import Any, Callable, cast
SOCKET_ROOT_DIR = Path("/tmp/neotest-python") SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
class LineReceiver: def get_tests(paths: Iterable[str]) -> list[str]:
def __init__(self, s: socket.socket) -> None: test_names: list[str] = []
self._s = s
self._s.setblocking(False)
self._tmp_buffer = ""
self._tmp_lines: deque[str] = deque()
def __iter__(self) -> Self:
return self
def __next__(self) -> str:
if self._tmp_lines:
return self._tmp_lines.popleft()
ready, _, _ = select.select([self._s], [], [], 0.1)
if ready:
msg = self._s.recv(256).decode()
else:
msg = ""
lines = msg.splitlines()
if not lines:
raise StopIteration
lines[0] = self._tmp_buffer + lines[0]
if not msg.endswith("\n"):
self._tmp_buffer = lines[-1]
else:
self._tmp_buffer = ""
self._tmp_lines.extend(lines)
return self._tmp_lines.popleft()
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
root = Path(os.curdir).absolute()
if Path(os.curdir).absolute().as_posix() not in sys.path:
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
for path in (Path(path) for path in paths): for path in (Path(path) for path in paths):
spec = importlib.util.spec_from_file_location(path.name, path) spec = importlib.util.spec_from_file_location(path.name, path)
if spec is None or spec.loader is None: if spec is None or spec.loader is None:
raise ModuleNotFoundError(path) raise ModuleNotFoundError
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) spec.loader.exec_module(mod)
@@ -75,58 +36,32 @@ def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
) )
for _, test in tests: for _, test in tests:
if not (marks := getattr(test, "pytestmark", None)): if not (marks := getattr(test, "pytestmark", None)):
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}") test_names.append(f"{test.__module__}::{test.__name__}")
continue continue
for mark in cast(Iterable[pytest.Mark], marks): for mark in cast(Iterable[pytest.Mark], marks):
if mark.name == "parametrize": if mark.name == "parametrize":
ids = mark.kwargs.get("ids") ids = mark.kwargs.get("ids", [])
argnames = mark.args[0] for id_, params in itertools.zip_longest(ids, mark.args[1]):
argvalues = mark.args[1] id_ = getattr(params, "id", None) or id_ or repr(params)
argnames, parametersets = ParameterSet._for_parametrize( # pyright: ignore[reportUnknownMemberType, reportPrivateUsage] test_names.append(f"{test.__module__}::{test.__name__}[{id_}]")
argnames,
argvalues,
None,
None, # pyright: ignore[reportArgumentType]
None, # pyright: ignore[reportArgumentType]
)
if ids is None:
idfn = None
ids_ = None
elif callable(ids):
idfn = ids
ids_ = None
else:
idfn = None
ids_ = pytest.Metafunc._validate_ids( # pyright: ignore[reportPrivateUsage]
None, # pyright: ignore[reportArgumentType]
ids,
parametersets,
test.__name__, # pyright: ignore[reportArgumentType]
)
id_maker = IdMaker( return test_names
argnames,
parametersets,
idfn,
ids_,
None,
nodeid=None,
func_name=test.__name__,
)
yield from (
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
for id_ in id_maker.make_unique_parameterset_ids()
)
def _close_socket(path: Path) -> None: def main(
if path.exists(): paths: Iterable[str],
os.unlink(path) collect_only: bool = True,
exit(0) quiet: bool = True,
verbosity: int = 0,
socket_mode: bool = False,
):
tests = get_tests(paths)
if tests:
print("\n".join(tests))
if not socket_mode:
return
async def serve_socket():
if not SOCKET_ROOT_DIR.exists(): if not SOCKET_ROOT_DIR.exists():
SOCKET_ROOT_DIR.mkdir() SOCKET_ROOT_DIR.mkdir()
@@ -138,43 +73,29 @@ async def serve_socket():
print(python_socket_path) print(python_socket_path)
return return
atexit.register(lambda: _close_socket(python_socket_path)) child_pid = os.fork()
if child_pid != 0:
signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path))
signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path))
async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
path = await reader.readline()
tests = "\n".join(
[test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast]
)
writer.write(f"{tests}\n".encode())
writer.close()
server = await asyncio.start_unix_server(
handle_client, python_socket_path.absolute().as_posix()
)
async with server:
print(python_socket_path)
sys.stdout.close()
await server.serve_forever()
def main(
paths: Iterable[str],
collect_only: bool = True,
quiet: bool = True,
verbosity: int = 0,
socket_mode: bool = False,
no_fork: bool = True,
):
tests = list(get_tests(paths))
if tests:
print("\n".join(tests))
if not socket_mode:
return return
asyncio.run(serve_socket())
atexit.register(lambda: os.unlink(python_socket_path))
signal.signal(signal.SIGTERM, lambda x, _: os.unlink(python_socket_path))
signal.signal(signal.SIGINT, lambda x, _: os.unlink(python_socket_path))
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(python_socket_path.absolute().as_posix())
sock.listen()
print(python_socket_path)
while ar := sock.accept():
con, _ = ar
try:
paths = con.recv(1000).decode().splitlines()
tests = get_tests(paths)
con.send("\n".join(tests).encode())
except Exception:
logging.exception("Error")
if __name__ == "__main__": if __name__ == "__main__":
@@ -183,7 +104,6 @@ if __name__ == "__main__":
parser.add_argument("-q", dest="quiet", action="store_true") parser.add_argument("-q", dest="quiet", action="store_true")
parser.add_argument("--verbosity", dest="verbosity", default=0) parser.add_argument("--verbosity", dest="verbosity", default=0)
parser.add_argument("-s", dest="socket_mode", action="store_true") parser.add_argument("-s", dest="socket_mode", action="store_true")
parser.add_argument("-f", dest="no_fork", action="store_true")
parser.add_argument( parser.add_argument(
"paths", "paths",
nargs="*", nargs="*",
@@ -196,5 +116,4 @@ if __name__ == "__main__":
collect_only=args.collect_only, collect_only=args.collect_only,
verbosity=args.verbosity, verbosity=args.verbosity,
socket_mode=args.socket_mode, socket_mode=args.socket_mode,
no_fork=args.no_fork,
) )

View File

@@ -2,7 +2,7 @@ from io import StringIO
import json import json
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
from . import params_getter import params_getter
import pytest import pytest
from _pytest._code.code import ExceptionRepr from _pytest._code.code import ExceptionRepr
@@ -220,8 +220,4 @@ def extract_test_name_template(args):
def collect(args): def collect(args):
params_getter.main( params_getter.main([], socket_mode=True)
[],
socket_mode=True,
no_fork=True,
)