feat(pytest): use socket instead of a shitton of processes

This commit is contained in:
Itai Bohadana
2025-08-26 15:48:59 +03:00
parent ed9b4d794b
commit 9b6cc8bd62
8 changed files with 371 additions and 50 deletions

View File

@@ -65,9 +65,13 @@ return function(config)
local python_command = config.get_python_command(root)
local runner = config.get_runner(python_command)
local positions = lib.treesitter.parse_positions(path, base.treesitter_queries(runner, config, python_command), {
require_namespaces = runner == "unittest",
})
local positions = lib.treesitter.parse_positions(
path,
base.treesitter_queries(runner, config, python_command),
{
require_namespaces = runner == "unittest",
}
)
if runner == "pytest" and config.pytest_discovery then
pytest.augment_positions(python_command, base.get_script_path(), path, positions, root)

View File

@@ -96,7 +96,7 @@ end
function M.get_script_path()
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
for _, path in ipairs(paths) do
if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then
if vim.endswith(path, ("python%sneotest.py"):format(lib.files.sep)) then
return path
end
end
@@ -110,17 +110,6 @@ end
---@return string
local function scan_test_function_pattern(runner, config, python_command)
local test_function_pattern = "^test"
if runner == "pytest" and config.pytest_discovery then
local cmd = vim.tbl_flatten({ python_command, M.get_script_path(), "--pytest-extract-test-name-template" })
local _, data = lib.process.run(cmd, { stdout = true, stderr = true })
for line in vim.gsplit(data.stdout, "\n", true) do
if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then
local pytest_option = vim.json.decode(line)
test_function_pattern = pytest_option.python_functions
end
end
end
return test_function_pattern
end
@@ -130,7 +119,8 @@ end
---@return string
M.treesitter_queries = function(runner, config, python_command)
local test_function_pattern = scan_test_function_pattern(runner, config, python_command)
return string.format([[
return string.format(
[[
;; Match undecorated functions
((function_definition
name: (identifier) @test.name)
@@ -158,7 +148,10 @@ M.treesitter_queries = function(runner, config, python_command)
@namespace.definition
(#not-has-parent? @namespace.definition decorated_definition)
)
]], test_function_pattern, test_function_pattern)
]],
test_function_pattern,
test_function_pattern
)
end
M.get_root =
@@ -192,7 +185,7 @@ function M.get_runner(python_path)
then
return vim_test_runner
end
local runner = M.module_exists("pytest", python_path) and "pytest"
local runner = M.module_exists("pytest_", python_path) and "pytest"
or M.module_exists("django", python_path) and "django"
or "unittest"
stored_runners[command_str] = runner

View File

@@ -1,4 +1,5 @@
local lib = require("neotest.lib")
local nio = require("nio")
local logger = require("neotest.logging")
local M = {}
@@ -26,6 +27,73 @@ local function add_test_instances(positions, test_params)
end
end
local socket_path = ""
local socket_future = nio.control.future()
local socket_start = nio.control.future()
--- Run a command, connect to the UNIX socket it prints, send messages, return response
-- @param cmd string: Command to run (should output socket path)
-- @param messages table|string: One or more messages to send
-- @param callback uv.read_start.callback: One or more messages to send
-- @return string: Concatenated response from server
local function get_socket_path(cmd, messages, callback)
-- 1. Run the command and capture its output (socket path)
local stdout = assert(vim.uv.new_pipe())
local stderr = assert(vim.uv.new_pipe())
local stdin = assert(vim.uv.new_pipe())
if socket_path == "" and not socket_start.is_set() then
socket_start.set()
local handle
handle, _ = vim.uv.spawn(cmd[1], {
stdio = { stdin, stdout, stderr },
detached = false,
args = #cmd > 1 and vim.list_slice(cmd, 2, #cmd) or { cmd[1] },
}, function(code, signal)
vim.uv.close(stdout)
vim.uv.close(stderr)
vim.uv.close(stdin)
vim.uv.close(handle)
end)
vim.uv.read_start(stdout, function(err, s)
if s ~= nil then
socket_path = string.gsub(s, "\n$", "")
socket_future.set()
end
end)
vim.uv.read_start(stderr, function(err, s)
if err ~= nil then
vim.print(err, s)
end
end)
end
socket_future.wait()
-- 2. Connect to the unix socket
local client = assert(vim.uv.new_pipe(false))
client:connect(socket_path, function(err)
if err ~= nil then
vim.print("Error", err)
end
-- 3. Send message(s)
if type(messages) == "string" then
messages = { messages }
end
for _, msg in ipairs(messages) do
vim.uv.write(client, msg .. "\n", function(err_write)
client:read_start(function(err_read, data)
if err_read ~= nil then
vim.print(err_read, data)
else
callback(err_read, data)
end
end)
end)
end
end)
end
---@async
---@param path string
---@return boolean
@@ -57,31 +125,28 @@ local function discover_params(python, script, path, positions, root)
logger.debug("Running test instance discovery:", cmd)
local test_params = {}
local res, data = lib.process.run(cmd, { stdout = true, stderr = true })
if res ~= 0 then
logger.warn("Pytest discovery failed")
if data.stderr then
logger.debug(data.stderr)
get_socket_path(cmd, path, function(err, data)
if err ~= nil then
vim.print(err, data)
return
end
return {}
end
for line in vim.gsplit(data, "\n", { trimempty = true, plain = true }) do
local param_index = string.find(line, "[", nil, true)
if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
local param_id = string.sub(line, param_index + 1, #line - 1)
for line in vim.gsplit(data.stdout, "\n", true) do
local param_index = string.find(line, "[", nil, true)
if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
local param_id = string.sub(line, param_index + 1, #line - 1)
if positions:get_key(test_id) then
if not test_params[test_id] then
test_params[test_id] = { param_id }
else
table.insert(test_params[test_id], param_id)
if positions:get_key(test_id) then
if not test_params[test_id] then
test_params[test_id] = { param_id }
else
table.insert(test_params[test_id], param_id)
end
end
end
end
end
return test_params
return add_test_instances(positions, test_params)
end)
end
---@async
@@ -93,8 +158,7 @@ end
---@param root string
function M.augment_positions(python, script, path, positions, root)
if has_parametrize(path) then
local test_params = discover_params(python, script, path, positions, root)
add_test_instances(positions, test_params)
discover_params(python, script, path, positions, root)
end
end

View File

@@ -14,7 +14,7 @@ class TestRunner(str, Enum):
def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter:
if runner == TestRunner.PYTEST:
from .pytest import PytestNeotestAdapter
from .pytest_ import PytestNeotestAdapter
return PytestNeotestAdapter(emit_parameterized_ids)
elif runner == TestRunner.UNITTEST:
@@ -53,14 +53,14 @@ parser.add_argument("args", nargs="*")
def main(argv: List[str]):
if "--pytest-collect" in argv:
argv.remove("--pytest-collect")
from .pytest import collect
from .pytest_ import collect
collect(argv)
return
if "--pytest-extract-test-name-template" in argv:
argv.remove("--pytest-extract-test-name-template")
from .pytest import extract_test_name_template
from .pytest_ import extract_test_name_template
extract_test_name_template(argv)
return

View File

@@ -0,0 +1,174 @@
import asyncio
import atexit
from collections import deque
from collections.abc import Iterable
import hashlib
import itertools
import logging
import os
import select
import signal
import socket
import sys
import time
import pytest
import importlib
import importlib.util
import argparse
import inspect
from pathlib import Path
from typing import Any, Callable, Generator, Self, cast
SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
class LineReceiver:
def __init__(self, s: socket.socket) -> None:
self._s = s
self._s.setblocking(False)
self._tmp_buffer = ""
self._tmp_lines: deque[str] = deque()
def __iter__(self) -> Self:
return self
def __next__(self) -> str:
if self._tmp_lines:
return self._tmp_lines.popleft()
ready, _, _ = select.select([self._s], [], [], 0.1)
if ready:
msg = self._s.recv(256).decode()
else:
msg = ""
lines = msg.splitlines()
if not lines:
raise StopIteration
lines[0] = self._tmp_buffer + lines[0]
if not msg.endswith("\n"):
self._tmp_buffer = lines[-1]
else:
self._tmp_buffer = ""
self._tmp_lines.extend(lines)
return self._tmp_lines.popleft()
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
root = Path(os.curdir).absolute()
if Path(os.curdir).absolute().as_posix() not in sys.path:
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
print(sys.path, file=sys.stderr)
for path in (Path(path) for path in paths):
spec = importlib.util.spec_from_file_location(path.name, path)
if spec is None or spec.loader is None:
raise ModuleNotFoundError(path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers(
mod,
predicate=lambda member: inspect.isfunction(member)
and member.__name__.startswith("test_"),
)
for _, test in tests:
if not (marks := getattr(test, "pytestmark", None)):
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}")
continue
for mark in cast(Iterable[pytest.Mark], marks):
if mark.name == "parametrize":
ids = mark.kwargs.get("ids", [])
for id_, params in itertools.zip_longest(ids, mark.args[1]):
id_ = (
getattr(params, "id", None)
or id_
or "-".join(repr(param) for param in params)
)
yield (
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
)
def _close_socket(path: Path) -> None:
if path.exists():
os.unlink(path)
exit(0)
async def serve_socket():
if not SOCKET_ROOT_DIR.exists():
SOCKET_ROOT_DIR.mkdir()
python_socket_path = (
SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex()
)
if python_socket_path.exists():
print(python_socket_path)
return
atexit.register(lambda: _close_socket(python_socket_path))
signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path))
signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path))
# signal.signal(signal.SIGKILL, lambda x, _: _close_socket(python_socket_path))
async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
path = await reader.readline()
tests = "\n".join(
[test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast]
)
writer.write(f"{tests}\n".encode())
writer.close()
server = await asyncio.start_unix_server(
handle_client, python_socket_path.absolute().as_posix()
)
async with server:
print(python_socket_path)
sys.stdout.close()
await server.serve_forever()
def main(
paths: Iterable[str],
collect_only: bool = True,
quiet: bool = True,
verbosity: int = 0,
socket_mode: bool = False,
no_fork: bool = True,
):
tests = list(get_tests(paths))
if tests:
print("\n".join(tests))
if not socket_mode:
return
asyncio.run(serve_socket())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--collect-only", dest="collect_only", action="store_true")
parser.add_argument("-q", dest="quiet", action="store_true")
parser.add_argument("--verbosity", dest="verbosity", default=0)
parser.add_argument("-s", dest="socket_mode", action="store_true")
parser.add_argument("-f", dest="no_fork", action="store_true")
parser.add_argument(
"paths",
nargs="*",
)
args = parser.parse_args()
paths: list[str] = args.paths
main(
paths=paths,
quiet=args.quiet,
collect_only=args.collect_only,
verbosity=args.verbosity,
socket_mode=args.socket_mode,
no_fork=args.no_fork,
)

View File

@@ -2,6 +2,7 @@ from io import StringIO
import json
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from . import params_getter
import pytest
from _pytest._code.code import ExceptionRepr
@@ -214,8 +215,13 @@ class TestNameTemplateExtractor:
def extract_test_name_template(args):
pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor])
# pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor])
pass
def collect(args):
pytest.main(["--collect-only", "--verbosity=0", "-q"] + args)
params_getter.main(
[],
socket_mode=True,
no_fork=True,
)

View File

@@ -1,6 +1,12 @@
[project]
name = "neotest-python"
version = "0.0.1"
dependencies = [
"pytest>=8.4.1",
]
[tools.black]
line-length=120
line-length = 120
[tool.isort]
profile = "black"
@@ -8,8 +14,7 @@ multi_line_output = 3
[tool.pytest.ini_options]
filterwarnings = [
"error",
"ignore::pytest.PytestCollectionWarning",
"ignore:::pynvim[.*]"
"error",
"ignore::pytest.PytestCollectionWarning",
"ignore:::pynvim[.*]",
]

75
uv.lock generated Normal file
View File

@@ -0,0 +1,75 @@
version = 1
revision = 2
requires-python = ">=3.13"
[[package]]
name = "colorama"
version = "0.4.6"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
]
[[package]]
name = "iniconfig"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
]
[[package]]
name = "neotest-python"
version = "0.0.1"
source = { virtual = "." }
dependencies = [
{ name = "pytest" },
]
[package.metadata]
requires-dist = [{ name = "pytest", specifier = ">=8.4.1" }]
[[package]]
name = "packaging"
version = "25.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "pygments"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
]
[[package]]
name = "pytest"
version = "8.4.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
]