feat(pytest): use socket instead of a shitton of processes

This commit is contained in:
Itai Bohadana
2025-08-26 15:48:59 +03:00
parent ed9b4d794b
commit 9f6fbd6e04
8 changed files with 401 additions and 50 deletions

View File

@@ -0,0 +1,200 @@
import asyncio
import atexit
from collections import deque
from collections.abc import Iterable
import hashlib
import os
import select
import signal
import socket
import sys
from _pytest.mark.structures import ParameterSet
from _pytest.python import IdMaker
import pytest
import importlib
import importlib.util
import argparse
import inspect
from pathlib import Path
from typing import Any, Callable, Generator, Self, cast
SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
class LineReceiver:
def __init__(self, s: socket.socket) -> None:
self._s = s
self._s.setblocking(False)
self._tmp_buffer = ""
self._tmp_lines: deque[str] = deque()
def __iter__(self) -> Self:
return self
def __next__(self) -> str:
if self._tmp_lines:
return self._tmp_lines.popleft()
ready, _, _ = select.select([self._s], [], [], 0.1)
if ready:
msg = self._s.recv(256).decode()
else:
msg = ""
lines = msg.splitlines()
if not lines:
raise StopIteration
lines[0] = self._tmp_buffer + lines[0]
if not msg.endswith("\n"):
self._tmp_buffer = lines[-1]
else:
self._tmp_buffer = ""
self._tmp_lines.extend(lines)
return self._tmp_lines.popleft()
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
root = Path(os.curdir).absolute()
if Path(os.curdir).absolute().as_posix() not in sys.path:
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
for path in (Path(path) for path in paths):
spec = importlib.util.spec_from_file_location(path.name, path)
if spec is None or spec.loader is None:
raise ModuleNotFoundError(path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers(
mod,
predicate=lambda member: inspect.isfunction(member)
and member.__name__.startswith("test_"),
)
for _, test in tests:
if not (marks := getattr(test, "pytestmark", None)):
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}")
continue
for mark in cast(Iterable[pytest.Mark], marks):
if mark.name == "parametrize":
ids = mark.kwargs.get("ids")
argnames = mark.args[0]
argvalues = mark.args[1]
argnames, parametersets = ParameterSet._for_parametrize( # pyright: ignore[reportUnknownMemberType, reportPrivateUsage]
argnames,
argvalues,
None,
None, # pyright: ignore[reportArgumentType]
None, # pyright: ignore[reportArgumentType]
)
if ids is None:
idfn = None
ids_ = None
elif callable(ids):
idfn = ids
ids_ = None
else:
idfn = None
ids_ = pytest.Metafunc._validate_ids( # pyright: ignore[reportPrivateUsage]
None, # pyright: ignore[reportArgumentType]
ids,
parametersets,
test.__name__, # pyright: ignore[reportArgumentType]
)
id_maker = IdMaker(
argnames,
parametersets,
idfn,
ids_,
None,
nodeid=None,
func_name=test.__name__,
)
yield from (
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
for id_ in id_maker.make_unique_parameterset_ids()
)
def _close_socket(path: Path) -> None:
if path.exists():
os.unlink(path)
exit(0)
async def serve_socket():
if not SOCKET_ROOT_DIR.exists():
SOCKET_ROOT_DIR.mkdir()
python_socket_path = (
SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex()
)
if python_socket_path.exists():
print(python_socket_path)
return
atexit.register(lambda: _close_socket(python_socket_path))
signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path))
signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path))
# signal.signal(signal.SIGKILL, lambda x, _: _close_socket(python_socket_path))
async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
path = await reader.readline()
tests = "\n".join(
[test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast]
)
writer.write(f"{tests}\n".encode())
writer.close()
server = await asyncio.start_unix_server(
handle_client, python_socket_path.absolute().as_posix()
)
async with server:
print(python_socket_path)
sys.stdout.close()
await server.serve_forever()
def main(
paths: Iterable[str],
collect_only: bool = True,
quiet: bool = True,
verbosity: int = 0,
socket_mode: bool = False,
no_fork: bool = True,
):
tests = list(get_tests(paths))
if tests:
print("\n".join(tests))
if not socket_mode:
return
asyncio.run(serve_socket())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--collect-only", dest="collect_only", action="store_true")
parser.add_argument("-q", dest="quiet", action="store_true")
parser.add_argument("--verbosity", dest="verbosity", default=0)
parser.add_argument("-s", dest="socket_mode", action="store_true")
parser.add_argument("-f", dest="no_fork", action="store_true")
parser.add_argument(
"paths",
nargs="*",
)
args = parser.parse_args()
paths: list[str] = args.paths
main(
paths=paths,
quiet=args.quiet,
collect_only=args.collect_only,
verbosity=args.verbosity,
socket_mode=args.socket_mode,
no_fork=args.no_fork,
)