feat(pytest): use socket instead of a shitton of processes

This commit is contained in:
Itai Bohadana
2025-08-26 15:48:59 +03:00
parent ed9b4d794b
commit cd32c2afde
8 changed files with 374 additions and 49 deletions

View File

@@ -65,9 +65,13 @@ return function(config)
local python_command = config.get_python_command(root) local python_command = config.get_python_command(root)
local runner = config.get_runner(python_command) local runner = config.get_runner(python_command)
local positions = lib.treesitter.parse_positions(path, base.treesitter_queries(runner, config, python_command), { local positions = lib.treesitter.parse_positions(
require_namespaces = runner == "unittest", path,
}) base.treesitter_queries(runner, config, python_command),
{
require_namespaces = runner == "unittest",
}
)
if runner == "pytest" and config.pytest_discovery then if runner == "pytest" and config.pytest_discovery then
pytest.augment_positions(python_command, base.get_script_path(), path, positions, root) pytest.augment_positions(python_command, base.get_script_path(), path, positions, root)

View File

@@ -96,7 +96,7 @@ end
function M.get_script_path() function M.get_script_path()
local paths = vim.api.nvim_get_runtime_file("neotest.py", true) local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
for _, path in ipairs(paths) do for _, path in ipairs(paths) do
if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then if vim.endswith(path, ("python%sneotest.py"):format(lib.files.sep)) then
return path return path
end end
end end
@@ -110,17 +110,6 @@ end
---@return string ---@return string
local function scan_test_function_pattern(runner, config, python_command) local function scan_test_function_pattern(runner, config, python_command)
local test_function_pattern = "^test" local test_function_pattern = "^test"
if runner == "pytest" and config.pytest_discovery then
local cmd = vim.tbl_flatten({ python_command, M.get_script_path(), "--pytest-extract-test-name-template" })
local _, data = lib.process.run(cmd, { stdout = true, stderr = true })
for line in vim.gsplit(data.stdout, "\n", true) do
if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then
local pytest_option = vim.json.decode(line)
test_function_pattern = pytest_option.python_functions
end
end
end
return test_function_pattern return test_function_pattern
end end
@@ -130,7 +119,8 @@ end
---@return string ---@return string
M.treesitter_queries = function(runner, config, python_command) M.treesitter_queries = function(runner, config, python_command)
local test_function_pattern = scan_test_function_pattern(runner, config, python_command) local test_function_pattern = scan_test_function_pattern(runner, config, python_command)
return string.format([[ return string.format(
[[
;; Match undecorated functions ;; Match undecorated functions
((function_definition ((function_definition
name: (identifier) @test.name) name: (identifier) @test.name)
@@ -158,7 +148,10 @@ M.treesitter_queries = function(runner, config, python_command)
@namespace.definition @namespace.definition
(#not-has-parent? @namespace.definition decorated_definition) (#not-has-parent? @namespace.definition decorated_definition)
) )
]], test_function_pattern, test_function_pattern) ]],
test_function_pattern,
test_function_pattern
)
end end
M.get_root = M.get_root =
@@ -192,7 +185,7 @@ function M.get_runner(python_path)
then then
return vim_test_runner return vim_test_runner
end end
local runner = M.module_exists("pytest", python_path) and "pytest" local runner = M.module_exists("pytest_", python_path) and "pytest"
or M.module_exists("django", python_path) and "django" or M.module_exists("django", python_path) and "django"
or "unittest" or "unittest"
stored_runners[command_str] = runner stored_runners[command_str] = runner

View File

@@ -1,4 +1,5 @@
local lib = require("neotest.lib") local lib = require("neotest.lib")
local nio = require("nio")
local logger = require("neotest.logging") local logger = require("neotest.logging")
local M = {} local M = {}
@@ -26,6 +27,73 @@ local function add_test_instances(positions, test_params)
end end
end end
local socket_path = ""
local socket_future = nio.control.future()
local socket_start = nio.control.future()
--- Run a command, connect to the UNIX socket it prints, send messages, return response
-- @param cmd string: Command to run (should output socket path)
-- @param messages table|string: One or more messages to send
-- @param callback uv.read_start.callback: One or more messages to send
-- @return string: Concatenated response from server
local function get_socket_path(cmd, messages, callback)
-- 1. Run the command and capture its output (socket path)
local stdout = assert(vim.uv.new_pipe())
local stderr = assert(vim.uv.new_pipe())
local stdin = assert(vim.uv.new_pipe())
if socket_path == "" and not socket_start.is_set() then
socket_start.set()
local handle
handle, _ = vim.uv.spawn(cmd[1], {
stdio = { stdin, stdout, stderr },
detached = false,
args = #cmd > 1 and vim.list_slice(cmd, 2, #cmd) or { cmd[1] },
}, function(code, signal)
vim.uv.close(stdout)
vim.uv.close(stderr)
vim.uv.close(stdin)
vim.uv.close(handle)
end)
vim.uv.read_start(stdout, function(err, s)
if s ~= nil then
socket_path = string.gsub(s, "\n$", "")
socket_future.set()
end
end)
vim.uv.read_start(stderr, function(err, s)
if err ~= nil then
vim.print(err, s)
end
end)
end
socket_future.wait()
-- 2. Connect to the unix socket
local client = assert(vim.uv.new_pipe(false))
client:connect(socket_path, function(err)
if err ~= nil then
vim.print("Error", err)
end
-- 3. Send message(s)
if type(messages) == "string" then
messages = { messages }
end
for _, msg in ipairs(messages) do
vim.uv.write(client, msg .. "\n", function(err_write)
client:read_start(function(err_read, data)
if err_read ~= nil then
vim.print(err_read, data)
else
callback(err_read, data)
end
end)
end)
end
end)
end
---@async ---@async
---@param path string ---@param path string
---@return boolean ---@return boolean
@@ -57,31 +125,32 @@ local function discover_params(python, script, path, positions, root)
logger.debug("Running test instance discovery:", cmd) logger.debug("Running test instance discovery:", cmd)
local test_params = {} local test_params = {}
local res, data = lib.process.run(cmd, { stdout = true, stderr = true }) get_socket_path(cmd, path, function(err, data)
if res ~= 0 then if err ~= nil then
logger.warn("Pytest discovery failed") vim.print(err, data)
if data.stderr then return
logger.debug(data.stderr) end
if data == nil then
return
end end
return {}
end
for line in vim.gsplit(data.stdout, "\n", true) do for line in vim.gsplit(data, "\n", { trimempty = true, plain = true }) do
local param_index = string.find(line, "[", nil, true) local param_index = string.find(line, "[", nil, true)
if param_index then if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1) local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
local param_id = string.sub(line, param_index + 1, #line - 1) local param_id = string.sub(line, param_index + 1, #line - 1)
if positions:get_key(test_id) then if positions:get_key(test_id) then
if not test_params[test_id] then if not test_params[test_id] then
test_params[test_id] = { param_id } test_params[test_id] = { param_id }
else else
table.insert(test_params[test_id], param_id) table.insert(test_params[test_id], param_id)
end
end end
end end
end end
end return add_test_instances(positions, test_params)
return test_params end)
end end
---@async ---@async
@@ -93,8 +162,7 @@ end
---@param root string ---@param root string
function M.augment_positions(python, script, path, positions, root) function M.augment_positions(python, script, path, positions, root)
if has_parametrize(path) then if has_parametrize(path) then
local test_params = discover_params(python, script, path, positions, root) discover_params(python, script, path, positions, root)
add_test_instances(positions, test_params)
end end
end end

View File

@@ -14,7 +14,7 @@ class TestRunner(str, Enum):
def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter: def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter:
if runner == TestRunner.PYTEST: if runner == TestRunner.PYTEST:
from .pytest import PytestNeotestAdapter from .pytest_ import PytestNeotestAdapter
return PytestNeotestAdapter(emit_parameterized_ids) return PytestNeotestAdapter(emit_parameterized_ids)
elif runner == TestRunner.UNITTEST: elif runner == TestRunner.UNITTEST:
@@ -53,14 +53,14 @@ parser.add_argument("args", nargs="*")
def main(argv: List[str]): def main(argv: List[str]):
if "--pytest-collect" in argv: if "--pytest-collect" in argv:
argv.remove("--pytest-collect") argv.remove("--pytest-collect")
from .pytest import collect from .pytest_ import collect
collect(argv) collect(argv)
return return
if "--pytest-extract-test-name-template" in argv: if "--pytest-extract-test-name-template" in argv:
argv.remove("--pytest-extract-test-name-template") argv.remove("--pytest-extract-test-name-template")
from .pytest import extract_test_name_template from .pytest_ import extract_test_name_template
extract_test_name_template(argv) extract_test_name_template(argv)
return return

View File

@@ -0,0 +1,174 @@
import asyncio
import atexit
from collections import deque
from collections.abc import Iterable
import hashlib
import itertools
import logging
import os
import select
import signal
import socket
import sys
import time
import pytest
import importlib
import importlib.util
import argparse
import inspect
from pathlib import Path
from typing import Any, Callable, Generator, Self, cast
SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
class LineReceiver:
def __init__(self, s: socket.socket) -> None:
self._s = s
self._s.setblocking(False)
self._tmp_buffer = ""
self._tmp_lines: deque[str] = deque()
def __iter__(self) -> Self:
return self
def __next__(self) -> str:
if self._tmp_lines:
return self._tmp_lines.popleft()
ready, _, _ = select.select([self._s], [], [], 0.1)
if ready:
msg = self._s.recv(256).decode()
else:
msg = ""
lines = msg.splitlines()
if not lines:
raise StopIteration
lines[0] = self._tmp_buffer + lines[0]
if not msg.endswith("\n"):
self._tmp_buffer = lines[-1]
else:
self._tmp_buffer = ""
self._tmp_lines.extend(lines)
return self._tmp_lines.popleft()
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
root = Path(os.curdir).absolute()
if Path(os.curdir).absolute().as_posix() not in sys.path:
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
print(sys.path, file=sys.stderr)
for path in (Path(path) for path in paths):
spec = importlib.util.spec_from_file_location(path.name, path)
if spec is None or spec.loader is None:
raise ModuleNotFoundError(path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers(
mod,
predicate=lambda member: inspect.isfunction(member)
and member.__name__.startswith("test_"),
)
for _, test in tests:
if not (marks := getattr(test, "pytestmark", None)):
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}")
continue
for mark in cast(Iterable[pytest.Mark], marks):
if mark.name == "parametrize":
ids = mark.kwargs.get("ids", [])
for id_, params in itertools.zip_longest(ids, mark.args[1]):
id_ = (
getattr(params, "id", None)
or id_
or "-".join(repr(param) for param in params)
)
yield (
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
)
def _close_socket(path: Path) -> None:
if path.exists():
os.unlink(path)
exit(0)
async def serve_socket():
if not SOCKET_ROOT_DIR.exists():
SOCKET_ROOT_DIR.mkdir()
python_socket_path = (
SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex()
)
if python_socket_path.exists():
print(python_socket_path)
return
atexit.register(lambda: _close_socket(python_socket_path))
signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path))
signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path))
# signal.signal(signal.SIGKILL, lambda x, _: _close_socket(python_socket_path))
async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
path = await reader.readline()
tests = "\n".join(
[test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast]
)
writer.write(f"{tests}\n".encode())
writer.close()
server = await asyncio.start_unix_server(
handle_client, python_socket_path.absolute().as_posix()
)
async with server:
print(python_socket_path)
sys.stdout.close()
await server.serve_forever()
def main(
paths: Iterable[str],
collect_only: bool = True,
quiet: bool = True,
verbosity: int = 0,
socket_mode: bool = False,
no_fork: bool = True,
):
tests = list(get_tests(paths))
if tests:
print("\n".join(tests))
if not socket_mode:
return
asyncio.run(serve_socket())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--collect-only", dest="collect_only", action="store_true")
parser.add_argument("-q", dest="quiet", action="store_true")
parser.add_argument("--verbosity", dest="verbosity", default=0)
parser.add_argument("-s", dest="socket_mode", action="store_true")
parser.add_argument("-f", dest="no_fork", action="store_true")
parser.add_argument(
"paths",
nargs="*",
)
args = parser.parse_args()
paths: list[str] = args.paths
main(
paths=paths,
quiet=args.quiet,
collect_only=args.collect_only,
verbosity=args.verbosity,
socket_mode=args.socket_mode,
no_fork=args.no_fork,
)

View File

@@ -2,6 +2,7 @@ from io import StringIO
import json import json
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
from . import params_getter
import pytest import pytest
from _pytest._code.code import ExceptionRepr from _pytest._code.code import ExceptionRepr
@@ -214,8 +215,13 @@ class TestNameTemplateExtractor:
def extract_test_name_template(args): def extract_test_name_template(args):
pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor]) # pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor])
pass
def collect(args): def collect(args):
pytest.main(["--collect-only", "--verbosity=0", "-q"] + args) params_getter.main(
[],
socket_mode=True,
no_fork=True,
)

View File

@@ -1,6 +1,12 @@
[project]
name = "neotest-python"
version = "0.0.1"
dependencies = [
"pytest>=8.4.1",
]
[tools.black] [tools.black]
line-length=120 line-length = 120
[tool.isort] [tool.isort]
profile = "black" profile = "black"
@@ -8,8 +14,7 @@ multi_line_output = 3
[tool.pytest.ini_options] [tool.pytest.ini_options]
filterwarnings = [ filterwarnings = [
"error", "error",
"ignore::pytest.PytestCollectionWarning", "ignore::pytest.PytestCollectionWarning",
"ignore:::pynvim[.*]" "ignore:::pynvim[.*]",
] ]

75
uv.lock generated Normal file
View File

@@ -0,0 +1,75 @@
version = 1
revision = 2
requires-python = ">=3.13"
[[package]]
name = "colorama"
version = "0.4.6"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
]
[[package]]
name = "iniconfig"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
]
[[package]]
name = "neotest-python"
version = "0.0.1"
source = { virtual = "." }
dependencies = [
{ name = "pytest" },
]
[package.metadata]
requires-dist = [{ name = "pytest", specifier = ">=8.4.1" }]
[[package]]
name = "packaging"
version = "25.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "pygments"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
]
[[package]]
name = "pytest"
version = "8.4.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
]