Compare commits

..

1 Commits

Author SHA1 Message Date
Itai Bohadana
ff75e11bca feat(pytest): use socket instead of a shitton of processes 2025-08-26 15:48:59 +03:00
6 changed files with 118 additions and 237 deletions

View File

@@ -65,13 +65,9 @@ return function(config)
local python_command = config.get_python_command(root)
local runner = config.get_runner(python_command)
local positions = lib.treesitter.parse_positions(
path,
base.treesitter_queries(runner, config, python_command),
{
require_namespaces = runner == "unittest",
}
)
local positions = lib.treesitter.parse_positions(path, base.treesitter_queries(runner, config, python_command), {
require_namespaces = runner == "unittest",
})
if runner == "pytest" and config.pytest_discovery then
pytest.augment_positions(python_command, base.get_script_path(), path, positions, root)

View File

@@ -96,7 +96,7 @@ end
function M.get_script_path()
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
for _, path in ipairs(paths) do
if vim.endswith(path, ("python%sneotest.py"):format(lib.files.sep)) then
if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then
return path
end
end

View File

@@ -1,9 +1,54 @@
local lib = require("neotest.lib")
local nio = require("nio")
local logger = require("neotest.logging")
local M = {}
local unix = require("socket.unix")
local socket_path = ""
--- Run a command, connect to the UNIX socket it prints, send messages, return response
-- @param cmd string: Command to run (should output socket path)
-- @param messages table|string: One or more messages to send
-- @return string: Concatenated response from server
local function talk_unix(cmd, messages)
-- 1. Run the command and capture its output (socket path)
if not socket_path then
local handle = assert(io.popen(cmd, "r"))
socket_path = handle:read("*l")
handle:close()
assert(socket_path, "Command did not return a socket path")
end
-- 2. Connect to the unix socket
local client = assert(unix())
assert(client:connect(socket_path))
-- 3. Send message(s)
if type(messages) == "string" then
messages = { messages }
end
for _, msg in ipairs(messages) do
assert(client:send(msg .. "\n"))
end
-- 4. Read response until EOF or timeout
client:settimeout(1)
local chunks = {}
while true do
local data, err = client:receive("*l")
if not data then
if err ~= "timeout" and err ~= "closed" then
error("Socket receive error: " .. tostring(err))
end
break
end
table.insert(chunks, data)
end
client:close()
return table.concat(chunks, "\n")
end
---@async
---Add test instances for path in root to positions
---@param positions neotest.Tree
@@ -16,7 +61,7 @@ local function add_test_instances(positions, test_params)
for _, params_str in ipairs(pos_params) do
local new_data = vim.tbl_extend("force", position, {
id = string.format("%s[%s]", position.id, params_str),
name = string.format("%s", params_str),
name = string.format("%s[%s]", position.name, params_str),
})
new_data.range = nil
@@ -27,74 +72,6 @@ local function add_test_instances(positions, test_params)
end
end
local socket_path = ""
local socket_future = nio.control.future()
local socket_start = nio.control.future()
--- Run a command, connect to the UNIX socket it prints, send messages, return response
-- @param cmd string: Command to run (should output socket path)
-- @param messages table|string: One or more messages to send
-- @param callback uv.read_start.callback: One or more messages to send
-- @return string: Concatenated response from server
local function get_socket_path(cmd, messages, callback)
-- 1. Run the command and capture its output (socket path)
if socket_path == "" and not socket_start.is_set() then
local stdout = assert(vim.uv.new_pipe(true))
local stderr = assert(vim.uv.new_pipe(true))
local stdin = assert(vim.uv.new_pipe(true))
socket_start.set()
local handle
handle, _ = vim.uv.spawn(cmd[1], {
stdio = { stdin, stdout, stderr },
detached = true,
args = #cmd > 1 and vim.list_slice(cmd, 2, #cmd) or { cmd[1] },
}, function(code, signal)
vim.uv.close(stdout)
vim.uv.close(stderr)
vim.uv.close(stdin)
vim.uv.close(handle)
end)
vim.uv.read_start(stdout, function(err, s)
if s ~= nil then
socket_path = string.gsub(s, "\n$", "")
socket_future.set()
end
end)
vim.uv.read_start(stderr, function(err, s)
if err ~= nil then
vim.print(err, s)
end
end)
end
socket_future.wait()
-- 2. Connect to the unix socket
local client = assert(vim.uv.new_pipe(false))
client:connect(socket_path, function(err)
if err ~= nil then
vim.print("Error", err)
end
-- 3. Send message(s)
if type(messages) == "string" then
messages = { messages }
end
for _, msg in ipairs(messages) do
vim.uv.write(client, msg .. "\n", function(err_write)
client:read_start(function(err_read, data)
if err_read ~= nil then
vim.print(err_read, data)
else
callback(err_read, data)
end
client:close()
end)
end)
end
end)
end
---@async
---@param path string
---@return boolean
@@ -122,36 +99,28 @@ end
---@param positions neotest.Tree
---@param root string
local function discover_params(python, script, path, positions, root)
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", path })
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", "-s", path })
logger.debug("Running test instance discovery:", cmd)
local test_params = {}
get_socket_path(cmd, path, function(err, data)
if err ~= nil then
vim.print(err, data)
return
end
if data == nil then
return
end
local data = talk_unix(cmd, path)
for line in vim.gsplit(data, "\n", { trimempty = true, plain = true }) do
local param_index = string.find(line, "[", nil, true)
if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
local param_id = string.sub(line, param_index + 1, #line - 1)
for line in vim.gsplit(data, "\n", true) do
local param_index = string.find(line, "[", nil, true)
if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
local param_id = string.sub(line, param_index + 1, #line - 1)
if positions:get_key(test_id) then
if not test_params[test_id] then
test_params[test_id] = { param_id }
else
table.insert(test_params[test_id], param_id)
end
if positions:get_key(test_id) then
if not test_params[test_id] then
test_params[test_id] = { param_id }
else
table.insert(test_params[test_id], param_id)
end
end
end
return add_test_instances(positions, test_params)
end)
end
return test_params
end
---@async
@@ -163,7 +132,8 @@ end
---@param root string
function M.augment_positions(python, script, path, positions, root)
if has_parametrize(path) then
discover_params(python, script, path, positions, root)
local test_params = discover_params(python, script, path, positions, root)
add_test_instances(positions, test_params)
end
end

View File

@@ -14,7 +14,7 @@ class TestRunner(str, Enum):
def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter:
if runner == TestRunner.PYTEST:
from .pytest_ import PytestNeotestAdapter
from .pytest import PytestNeotestAdapter
return PytestNeotestAdapter(emit_parameterized_ids)
elif runner == TestRunner.UNITTEST:
@@ -53,14 +53,14 @@ parser.add_argument("args", nargs="*")
def main(argv: List[str]):
if "--pytest-collect" in argv:
argv.remove("--pytest-collect")
from .pytest_ import collect
from .pytest import collect
collect(argv)
return
if "--pytest-extract-test-name-template" in argv:
argv.remove("--pytest-extract-test-name-template")
from .pytest_ import extract_test_name_template
from .pytest import extract_test_name_template
extract_test_name_template(argv)
return

View File

@@ -1,69 +1,30 @@
import asyncio
import atexit
from collections import deque
from collections.abc import Iterable
import hashlib
import itertools
import logging
import os
import select
import signal
import socket
import sys
from _pytest.mark.structures import ParameterSet
from _pytest.python import IdMaker
import pytest
import importlib
import importlib.util
import argparse
import inspect
from pathlib import Path
from typing import Any, Callable, Generator, Self, cast
from typing import Any, Callable, cast
SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
class LineReceiver:
def __init__(self, s: socket.socket) -> None:
self._s = s
self._s.setblocking(False)
self._tmp_buffer = ""
self._tmp_lines: deque[str] = deque()
def __iter__(self) -> Self:
return self
def __next__(self) -> str:
if self._tmp_lines:
return self._tmp_lines.popleft()
ready, _, _ = select.select([self._s], [], [], 0.1)
if ready:
msg = self._s.recv(256).decode()
else:
msg = ""
lines = msg.splitlines()
if not lines:
raise StopIteration
lines[0] = self._tmp_buffer + lines[0]
if not msg.endswith("\n"):
self._tmp_buffer = lines[-1]
else:
self._tmp_buffer = ""
self._tmp_lines.extend(lines)
return self._tmp_lines.popleft()
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
root = Path(os.curdir).absolute()
if Path(os.curdir).absolute().as_posix() not in sys.path:
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
def get_tests(paths: Iterable[str]) -> list[str]:
test_names: list[str] = []
for path in (Path(path) for path in paths):
spec = importlib.util.spec_from_file_location(path.name, path)
if spec is None or spec.loader is None:
raise ModuleNotFoundError(path)
raise ModuleNotFoundError
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
@@ -75,58 +36,32 @@ def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
)
for _, test in tests:
if not (marks := getattr(test, "pytestmark", None)):
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}")
test_names.append(f"{test.__module__}::{test.__name__}")
continue
for mark in cast(Iterable[pytest.Mark], marks):
if mark.name == "parametrize":
ids = mark.kwargs.get("ids")
argnames = mark.args[0]
argvalues = mark.args[1]
argnames, parametersets = ParameterSet._for_parametrize( # pyright: ignore[reportUnknownMemberType, reportPrivateUsage]
argnames,
argvalues,
None,
None, # pyright: ignore[reportArgumentType]
None, # pyright: ignore[reportArgumentType]
)
if ids is None:
idfn = None
ids_ = None
elif callable(ids):
idfn = ids
ids_ = None
else:
idfn = None
ids_ = pytest.Metafunc._validate_ids( # pyright: ignore[reportPrivateUsage]
None, # pyright: ignore[reportArgumentType]
ids,
parametersets,
test.__name__, # pyright: ignore[reportArgumentType]
)
ids = mark.kwargs.get("ids", [])
for id_, params in itertools.zip_longest(ids, mark.args[1]):
id_ = getattr(params, "id", None) or id_ or repr(params)
test_names.append(f"{test.__module__}::{test.__name__}[{id_}]")
id_maker = IdMaker(
argnames,
parametersets,
idfn,
ids_,
None,
nodeid=None,
func_name=test.__name__,
)
yield from (
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
for id_ in id_maker.make_unique_parameterset_ids()
)
return test_names
def _close_socket(path: Path) -> None:
if path.exists():
os.unlink(path)
exit(0)
def main(
paths: Iterable[str],
collect_only: bool = True,
quiet: bool = True,
verbosity: int = 0,
socket_mode: bool = False,
):
tests = get_tests(paths)
if tests:
print("\n".join(tests))
if not socket_mode:
return
async def serve_socket():
if not SOCKET_ROOT_DIR.exists():
SOCKET_ROOT_DIR.mkdir()
@@ -138,43 +73,29 @@ async def serve_socket():
print(python_socket_path)
return
atexit.register(lambda: _close_socket(python_socket_path))
signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path))
signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path))
async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
path = await reader.readline()
tests = "\n".join(
[test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast]
)
writer.write(f"{tests}\n".encode())
writer.close()
server = await asyncio.start_unix_server(
handle_client, python_socket_path.absolute().as_posix()
)
async with server:
print(python_socket_path)
sys.stdout.close()
await server.serve_forever()
def main(
paths: Iterable[str],
collect_only: bool = True,
quiet: bool = True,
verbosity: int = 0,
socket_mode: bool = False,
no_fork: bool = True,
):
tests = list(get_tests(paths))
if tests:
print("\n".join(tests))
if not socket_mode:
child_pid = os.fork()
if child_pid != 0:
return
asyncio.run(serve_socket())
atexit.register(lambda: os.unlink(python_socket_path))
signal.signal(signal.SIGTERM, lambda x, _: os.unlink(python_socket_path))
signal.signal(signal.SIGINT, lambda x, _: os.unlink(python_socket_path))
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(python_socket_path.absolute().as_posix())
sock.listen()
print(python_socket_path)
while ar := sock.accept():
con, _ = ar
try:
paths = con.recv(1000).decode().splitlines()
tests = get_tests(paths)
con.send("\n".join(tests).encode())
except Exception:
logging.exception("Error")
if __name__ == "__main__":
@@ -183,7 +104,6 @@ if __name__ == "__main__":
parser.add_argument("-q", dest="quiet", action="store_true")
parser.add_argument("--verbosity", dest="verbosity", default=0)
parser.add_argument("-s", dest="socket_mode", action="store_true")
parser.add_argument("-f", dest="no_fork", action="store_true")
parser.add_argument(
"paths",
nargs="*",
@@ -196,5 +116,4 @@ if __name__ == "__main__":
collect_only=args.collect_only,
verbosity=args.verbosity,
socket_mode=args.socket_mode,
no_fork=args.no_fork,
)

View File

@@ -2,7 +2,7 @@ from io import StringIO
import json
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from . import params_getter
import params_getter
import pytest
from _pytest._code.code import ExceptionRepr
@@ -220,8 +220,4 @@ def extract_test_name_template(args):
def collect(args):
params_getter.main(
[],
socket_mode=True,
no_fork=True,
)
params_getter.main([], socket_mode=True)