From 9f6fbd6e04ab3e14d39798744f57bb225009b6f6 Mon Sep 17 00:00:00 2001 From: Itai Bohadana Date: Tue, 26 Aug 2025 15:48:59 +0300 Subject: [PATCH] feat(pytest): use socket instead of a shitton of processes --- lua/neotest-python/adapter.lua | 10 +- lua/neotest-python/base.lua | 23 +-- lua/neotest-python/pytest.lua | 112 ++++++++++--- neotest_python/__init__.py | 6 +- neotest_python/params_getter.py | 200 +++++++++++++++++++++++ neotest_python/{pytest.py => pytest_.py} | 10 +- pyproject.toml | 15 +- uv.lock | 75 +++++++++ 8 files changed, 401 insertions(+), 50 deletions(-) create mode 100644 neotest_python/params_getter.py rename neotest_python/{pytest.py => pytest_.py} (97%) create mode 100644 uv.lock diff --git a/lua/neotest-python/adapter.lua b/lua/neotest-python/adapter.lua index d0b5999..2000875 100644 --- a/lua/neotest-python/adapter.lua +++ b/lua/neotest-python/adapter.lua @@ -65,9 +65,13 @@ return function(config) local python_command = config.get_python_command(root) local runner = config.get_runner(python_command) - local positions = lib.treesitter.parse_positions(path, base.treesitter_queries(runner, config, python_command), { - require_namespaces = runner == "unittest", - }) + local positions = lib.treesitter.parse_positions( + path, + base.treesitter_queries(runner, config, python_command), + { + require_namespaces = runner == "unittest", + } + ) if runner == "pytest" and config.pytest_discovery then pytest.augment_positions(python_command, base.get_script_path(), path, positions, root) diff --git a/lua/neotest-python/base.lua b/lua/neotest-python/base.lua index 6d4a3d3..0ac3ddd 100644 --- a/lua/neotest-python/base.lua +++ b/lua/neotest-python/base.lua @@ -96,7 +96,7 @@ end function M.get_script_path() local paths = vim.api.nvim_get_runtime_file("neotest.py", true) for _, path in ipairs(paths) do - if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then + if vim.endswith(path, ("python%sneotest.py"):format(lib.files.sep)) then return path end end @@ -110,17 +110,6 @@ end ---@return string local function scan_test_function_pattern(runner, config, python_command) local test_function_pattern = "^test" - if runner == "pytest" and config.pytest_discovery then - local cmd = vim.tbl_flatten({ python_command, M.get_script_path(), "--pytest-extract-test-name-template" }) - local _, data = lib.process.run(cmd, { stdout = true, stderr = true }) - - for line in vim.gsplit(data.stdout, "\n", true) do - if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then - local pytest_option = vim.json.decode(line) - test_function_pattern = pytest_option.python_functions - end - end - end return test_function_pattern end @@ -130,7 +119,8 @@ end ---@return string M.treesitter_queries = function(runner, config, python_command) local test_function_pattern = scan_test_function_pattern(runner, config, python_command) - return string.format([[ + return string.format( + [[ ;; Match undecorated functions ((function_definition name: (identifier) @test.name) @@ -158,7 +148,10 @@ M.treesitter_queries = function(runner, config, python_command) @namespace.definition (#not-has-parent? @namespace.definition decorated_definition) ) - ]], test_function_pattern, test_function_pattern) + ]], + test_function_pattern, + test_function_pattern + ) end M.get_root = @@ -192,7 +185,7 @@ function M.get_runner(python_path) then return vim_test_runner end - local runner = M.module_exists("pytest", python_path) and "pytest" + local runner = M.module_exists("pytest_", python_path) and "pytest" or M.module_exists("django", python_path) and "django" or "unittest" stored_runners[command_str] = runner diff --git a/lua/neotest-python/pytest.lua b/lua/neotest-python/pytest.lua index d462158..91fbaa7 100644 --- a/lua/neotest-python/pytest.lua +++ b/lua/neotest-python/pytest.lua @@ -1,4 +1,5 @@ local lib = require("neotest.lib") +local nio = require("nio") local logger = require("neotest.logging") local M = {} @@ -15,7 +16,7 @@ local function add_test_instances(positions, test_params) for _, params_str in ipairs(pos_params) do local new_data = vim.tbl_extend("force", position, { id = string.format("%s[%s]", position.id, params_str), - name = string.format("%s[%s]", position.name, params_str), + name = string.format("%s", params_str), }) new_data.range = nil @@ -26,6 +27,73 @@ local function add_test_instances(positions, test_params) end end +local socket_path = "" +local socket_future = nio.control.future() +local socket_start = nio.control.future() +--- Run a command, connect to the UNIX socket it prints, send messages, return response +-- @param cmd string: Command to run (should output socket path) +-- @param messages table|string: One or more messages to send +-- @param callback uv.read_start.callback: One or more messages to send +-- @return string: Concatenated response from server +local function get_socket_path(cmd, messages, callback) + -- 1. Run the command and capture its output (socket path) + local stdout = assert(vim.uv.new_pipe()) + local stderr = assert(vim.uv.new_pipe()) + local stdin = assert(vim.uv.new_pipe()) + if socket_path == "" and not socket_start.is_set() then + socket_start.set() + local handle + handle, _ = vim.uv.spawn(cmd[1], { + stdio = { stdin, stdout, stderr }, + detached = false, + args = #cmd > 1 and vim.list_slice(cmd, 2, #cmd) or { cmd[1] }, + }, function(code, signal) + vim.uv.close(stdout) + vim.uv.close(stderr) + vim.uv.close(stdin) + vim.uv.close(handle) + end) + + vim.uv.read_start(stdout, function(err, s) + if s ~= nil then + socket_path = string.gsub(s, "\n$", "") + socket_future.set() + end + end) + + vim.uv.read_start(stderr, function(err, s) + if err ~= nil then + vim.print(err, s) + end + end) + end + socket_future.wait() + + -- 2. Connect to the unix socket + local client = assert(vim.uv.new_pipe(false)) + + client:connect(socket_path, function(err) + if err ~= nil then + vim.print("Error", err) + end + -- 3. Send message(s) + if type(messages) == "string" then + messages = { messages } + end + for _, msg in ipairs(messages) do + vim.uv.write(client, msg .. "\n", function(err_write) + client:read_start(function(err_read, data) + if err_read ~= nil then + vim.print(err_read, data) + else + callback(err_read, data) + end + end) + end) + end + end) +end + ---@async ---@param path string ---@return boolean @@ -57,31 +125,32 @@ local function discover_params(python, script, path, positions, root) logger.debug("Running test instance discovery:", cmd) local test_params = {} - local res, data = lib.process.run(cmd, { stdout = true, stderr = true }) - if res ~= 0 then - logger.warn("Pytest discovery failed") - if data.stderr then - logger.debug(data.stderr) + get_socket_path(cmd, path, function(err, data) + if err ~= nil then + vim.print(err, data) + return + end + if data == nil then + return end - return {} - end - for line in vim.gsplit(data.stdout, "\n", true) do - local param_index = string.find(line, "[", nil, true) - if param_index then - local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1) - local param_id = string.sub(line, param_index + 1, #line - 1) + for line in vim.gsplit(data, "\n", { trimempty = true, plain = true }) do + local param_index = string.find(line, "[", nil, true) + if param_index then + local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1) + local param_id = string.sub(line, param_index + 1, #line - 1) - if positions:get_key(test_id) then - if not test_params[test_id] then - test_params[test_id] = { param_id } - else - table.insert(test_params[test_id], param_id) + if positions:get_key(test_id) then + if not test_params[test_id] then + test_params[test_id] = { param_id } + else + table.insert(test_params[test_id], param_id) + end end end end - end - return test_params + return add_test_instances(positions, test_params) + end) end ---@async @@ -93,8 +162,7 @@ end ---@param root string function M.augment_positions(python, script, path, positions, root) if has_parametrize(path) then - local test_params = discover_params(python, script, path, positions, root) - add_test_instances(positions, test_params) + discover_params(python, script, path, positions, root) end end diff --git a/neotest_python/__init__.py b/neotest_python/__init__.py index f137834..2719f93 100644 --- a/neotest_python/__init__.py +++ b/neotest_python/__init__.py @@ -14,7 +14,7 @@ class TestRunner(str, Enum): def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter: if runner == TestRunner.PYTEST: - from .pytest import PytestNeotestAdapter + from .pytest_ import PytestNeotestAdapter return PytestNeotestAdapter(emit_parameterized_ids) elif runner == TestRunner.UNITTEST: @@ -53,14 +53,14 @@ parser.add_argument("args", nargs="*") def main(argv: List[str]): if "--pytest-collect" in argv: argv.remove("--pytest-collect") - from .pytest import collect + from .pytest_ import collect collect(argv) return if "--pytest-extract-test-name-template" in argv: argv.remove("--pytest-extract-test-name-template") - from .pytest import extract_test_name_template + from .pytest_ import extract_test_name_template extract_test_name_template(argv) return diff --git a/neotest_python/params_getter.py b/neotest_python/params_getter.py new file mode 100644 index 0000000..8f0ba13 --- /dev/null +++ b/neotest_python/params_getter.py @@ -0,0 +1,200 @@ +import asyncio +import atexit +from collections import deque +from collections.abc import Iterable +import hashlib +import os +import select +import signal +import socket +import sys +from _pytest.mark.structures import ParameterSet +from _pytest.python import IdMaker +import pytest +import importlib +import importlib.util +import argparse +import inspect +from pathlib import Path +from typing import Any, Callable, Generator, Self, cast + +SOCKET_ROOT_DIR = Path("/tmp/neotest-python") + + +class LineReceiver: + def __init__(self, s: socket.socket) -> None: + self._s = s + self._s.setblocking(False) + self._tmp_buffer = "" + self._tmp_lines: deque[str] = deque() + + def __iter__(self) -> Self: + return self + + def __next__(self) -> str: + if self._tmp_lines: + return self._tmp_lines.popleft() + + ready, _, _ = select.select([self._s], [], [], 0.1) + if ready: + msg = self._s.recv(256).decode() + else: + msg = "" + lines = msg.splitlines() + + if not lines: + raise StopIteration + + lines[0] = self._tmp_buffer + lines[0] + if not msg.endswith("\n"): + self._tmp_buffer = lines[-1] + else: + self._tmp_buffer = "" + self._tmp_lines.extend(lines) + + return self._tmp_lines.popleft() + + +def get_tests(paths: Iterable[str]) -> Generator[str, None, None]: + root = Path(os.curdir).absolute() + if Path(os.curdir).absolute().as_posix() not in sys.path: + sys.path.insert(0, Path(os.curdir).absolute().as_posix()) + + for path in (Path(path) for path in paths): + spec = importlib.util.spec_from_file_location(path.name, path) + if spec is None or spec.loader is None: + raise ModuleNotFoundError(path) + + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers( + mod, + predicate=lambda member: inspect.isfunction(member) + and member.__name__.startswith("test_"), + ) + for _, test in tests: + if not (marks := getattr(test, "pytestmark", None)): + yield (f"{path.relative_to(root).as_posix()}::{test.__name__}") + continue + + for mark in cast(Iterable[pytest.Mark], marks): + if mark.name == "parametrize": + ids = mark.kwargs.get("ids") + argnames = mark.args[0] + argvalues = mark.args[1] + argnames, parametersets = ParameterSet._for_parametrize( # pyright: ignore[reportUnknownMemberType, reportPrivateUsage] + argnames, + argvalues, + None, + None, # pyright: ignore[reportArgumentType] + None, # pyright: ignore[reportArgumentType] + ) + if ids is None: + idfn = None + ids_ = None + elif callable(ids): + idfn = ids + ids_ = None + else: + idfn = None + ids_ = pytest.Metafunc._validate_ids( # pyright: ignore[reportPrivateUsage] + None, # pyright: ignore[reportArgumentType] + ids, + parametersets, + test.__name__, # pyright: ignore[reportArgumentType] + ) + + id_maker = IdMaker( + argnames, + parametersets, + idfn, + ids_, + None, + nodeid=None, + func_name=test.__name__, + ) + yield from ( + f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]" + for id_ in id_maker.make_unique_parameterset_ids() + ) + + +def _close_socket(path: Path) -> None: + if path.exists(): + os.unlink(path) + exit(0) + + +async def serve_socket(): + if not SOCKET_ROOT_DIR.exists(): + SOCKET_ROOT_DIR.mkdir() + + python_socket_path = ( + SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex() + ) + + if python_socket_path.exists(): + print(python_socket_path) + return + atexit.register(lambda: _close_socket(python_socket_path)) + + signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path)) + signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path)) + # signal.signal(signal.SIGKILL, lambda x, _: _close_socket(python_socket_path)) + + async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + path = await reader.readline() + tests = "\n".join( + [test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast] + ) + writer.write(f"{tests}\n".encode()) + writer.close() + + server = await asyncio.start_unix_server( + handle_client, python_socket_path.absolute().as_posix() + ) + + async with server: + print(python_socket_path) + sys.stdout.close() + await server.serve_forever() + + +def main( + paths: Iterable[str], + collect_only: bool = True, + quiet: bool = True, + verbosity: int = 0, + socket_mode: bool = False, + no_fork: bool = True, +): + tests = list(get_tests(paths)) + if tests: + print("\n".join(tests)) + if not socket_mode: + return + asyncio.run(serve_socket()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--collect-only", dest="collect_only", action="store_true") + parser.add_argument("-q", dest="quiet", action="store_true") + parser.add_argument("--verbosity", dest="verbosity", default=0) + parser.add_argument("-s", dest="socket_mode", action="store_true") + parser.add_argument("-f", dest="no_fork", action="store_true") + parser.add_argument( + "paths", + nargs="*", + ) + args = parser.parse_args() + paths: list[str] = args.paths + main( + paths=paths, + quiet=args.quiet, + collect_only=args.collect_only, + verbosity=args.verbosity, + socket_mode=args.socket_mode, + no_fork=args.no_fork, + ) diff --git a/neotest_python/pytest.py b/neotest_python/pytest_.py similarity index 97% rename from neotest_python/pytest.py rename to neotest_python/pytest_.py index e12beb2..b5d55f8 100644 --- a/neotest_python/pytest.py +++ b/neotest_python/pytest_.py @@ -2,6 +2,7 @@ from io import StringIO import json from pathlib import Path from typing import Callable, Dict, List, Optional, Union +from . import params_getter import pytest from _pytest._code.code import ExceptionRepr @@ -214,8 +215,13 @@ class TestNameTemplateExtractor: def extract_test_name_template(args): - pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor]) + # pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor]) + pass def collect(args): - pytest.main(["--collect-only", "--verbosity=0", "-q"] + args) + params_getter.main( + [], + socket_mode=True, + no_fork=True, + ) diff --git a/pyproject.toml b/pyproject.toml index 780b5ab..fe88029 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,12 @@ +[project] +name = "neotest-python" +version = "0.0.1" +dependencies = [ + "pytest>=8.4.1", +] [tools.black] -line-length=120 +line-length = 120 [tool.isort] profile = "black" @@ -8,8 +14,7 @@ multi_line_output = 3 [tool.pytest.ini_options] filterwarnings = [ - "error", - "ignore::pytest.PytestCollectionWarning", - "ignore:::pynvim[.*]" + "error", + "ignore::pytest.PytestCollectionWarning", + "ignore:::pynvim[.*]", ] - diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..5213b9b --- /dev/null +++ b/uv.lock @@ -0,0 +1,75 @@ +version = 1 +revision = 2 +requires-python = ">=3.13" + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + +[[package]] +name = "neotest-python" +version = "0.0.1" +source = { virtual = "." } +dependencies = [ + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [{ name = "pytest", specifier = ">=8.4.1" }] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, +]