feat(pytest): use socket instead of a shitton of processes

This commit is contained in:
Itai Bohadana
2025-08-26 15:48:59 +03:00
parent ed9b4d794b
commit cd32c2afde
8 changed files with 374 additions and 49 deletions

View File

@@ -14,7 +14,7 @@ class TestRunner(str, Enum):
def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter:
if runner == TestRunner.PYTEST:
from .pytest import PytestNeotestAdapter
from .pytest_ import PytestNeotestAdapter
return PytestNeotestAdapter(emit_parameterized_ids)
elif runner == TestRunner.UNITTEST:
@@ -53,14 +53,14 @@ parser.add_argument("args", nargs="*")
def main(argv: List[str]):
if "--pytest-collect" in argv:
argv.remove("--pytest-collect")
from .pytest import collect
from .pytest_ import collect
collect(argv)
return
if "--pytest-extract-test-name-template" in argv:
argv.remove("--pytest-extract-test-name-template")
from .pytest import extract_test_name_template
from .pytest_ import extract_test_name_template
extract_test_name_template(argv)
return

View File

@@ -0,0 +1,174 @@
import asyncio
import atexit
from collections import deque
from collections.abc import Iterable
import hashlib
import itertools
import logging
import os
import select
import signal
import socket
import sys
import time
import pytest
import importlib
import importlib.util
import argparse
import inspect
from pathlib import Path
from typing import Any, Callable, Generator, Self, cast
SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
class LineReceiver:
def __init__(self, s: socket.socket) -> None:
self._s = s
self._s.setblocking(False)
self._tmp_buffer = ""
self._tmp_lines: deque[str] = deque()
def __iter__(self) -> Self:
return self
def __next__(self) -> str:
if self._tmp_lines:
return self._tmp_lines.popleft()
ready, _, _ = select.select([self._s], [], [], 0.1)
if ready:
msg = self._s.recv(256).decode()
else:
msg = ""
lines = msg.splitlines()
if not lines:
raise StopIteration
lines[0] = self._tmp_buffer + lines[0]
if not msg.endswith("\n"):
self._tmp_buffer = lines[-1]
else:
self._tmp_buffer = ""
self._tmp_lines.extend(lines)
return self._tmp_lines.popleft()
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
root = Path(os.curdir).absolute()
if Path(os.curdir).absolute().as_posix() not in sys.path:
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
print(sys.path, file=sys.stderr)
for path in (Path(path) for path in paths):
spec = importlib.util.spec_from_file_location(path.name, path)
if spec is None or spec.loader is None:
raise ModuleNotFoundError(path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers(
mod,
predicate=lambda member: inspect.isfunction(member)
and member.__name__.startswith("test_"),
)
for _, test in tests:
if not (marks := getattr(test, "pytestmark", None)):
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}")
continue
for mark in cast(Iterable[pytest.Mark], marks):
if mark.name == "parametrize":
ids = mark.kwargs.get("ids", [])
for id_, params in itertools.zip_longest(ids, mark.args[1]):
id_ = (
getattr(params, "id", None)
or id_
or "-".join(repr(param) for param in params)
)
yield (
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
)
def _close_socket(path: Path) -> None:
if path.exists():
os.unlink(path)
exit(0)
async def serve_socket():
if not SOCKET_ROOT_DIR.exists():
SOCKET_ROOT_DIR.mkdir()
python_socket_path = (
SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex()
)
if python_socket_path.exists():
print(python_socket_path)
return
atexit.register(lambda: _close_socket(python_socket_path))
signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path))
signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path))
# signal.signal(signal.SIGKILL, lambda x, _: _close_socket(python_socket_path))
async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
path = await reader.readline()
tests = "\n".join(
[test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast]
)
writer.write(f"{tests}\n".encode())
writer.close()
server = await asyncio.start_unix_server(
handle_client, python_socket_path.absolute().as_posix()
)
async with server:
print(python_socket_path)
sys.stdout.close()
await server.serve_forever()
def main(
paths: Iterable[str],
collect_only: bool = True,
quiet: bool = True,
verbosity: int = 0,
socket_mode: bool = False,
no_fork: bool = True,
):
tests = list(get_tests(paths))
if tests:
print("\n".join(tests))
if not socket_mode:
return
asyncio.run(serve_socket())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--collect-only", dest="collect_only", action="store_true")
parser.add_argument("-q", dest="quiet", action="store_true")
parser.add_argument("--verbosity", dest="verbosity", default=0)
parser.add_argument("-s", dest="socket_mode", action="store_true")
parser.add_argument("-f", dest="no_fork", action="store_true")
parser.add_argument(
"paths",
nargs="*",
)
args = parser.parse_args()
paths: list[str] = args.paths
main(
paths=paths,
quiet=args.quiet,
collect_only=args.collect_only,
verbosity=args.verbosity,
socket_mode=args.socket_mode,
no_fork=args.no_fork,
)

View File

@@ -2,6 +2,7 @@ from io import StringIO
import json
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
from . import params_getter
import pytest
from _pytest._code.code import ExceptionRepr
@@ -214,8 +215,13 @@ class TestNameTemplateExtractor:
def extract_test_name_template(args):
pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor])
# pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor])
pass
def collect(args):
pytest.main(["--collect-only", "--verbosity=0", "-q"] + args)
params_getter.main(
[],
socket_mode=True,
no_fork=True,
)