import asyncio import atexit from collections import deque from collections.abc import Iterable import hashlib import os import select import signal import socket import sys from _pytest.mark.structures import ParameterSet from _pytest.python import IdMaker import pytest import importlib import importlib.util import argparse import inspect from pathlib import Path from typing import Any, Callable, Generator, Self, cast SOCKET_ROOT_DIR = Path("/tmp/neotest-python") class LineReceiver: def __init__(self, s: socket.socket) -> None: self._s = s self._s.setblocking(False) self._tmp_buffer = "" self._tmp_lines: deque[str] = deque() def __iter__(self) -> Self: return self def __next__(self) -> str: if self._tmp_lines: return self._tmp_lines.popleft() ready, _, _ = select.select([self._s], [], [], 0.1) if ready: msg = self._s.recv(256).decode() else: msg = "" lines = msg.splitlines() if not lines: raise StopIteration lines[0] = self._tmp_buffer + lines[0] if not msg.endswith("\n"): self._tmp_buffer = lines[-1] else: self._tmp_buffer = "" self._tmp_lines.extend(lines) return self._tmp_lines.popleft() def get_tests(paths: Iterable[str]) -> Generator[str, None, None]: root = Path(os.curdir).absolute() if Path(os.curdir).absolute().as_posix() not in sys.path: sys.path.insert(0, Path(os.curdir).absolute().as_posix()) for path in (Path(path) for path in paths): spec = importlib.util.spec_from_file_location(path.name, path) if spec is None or spec.loader is None: raise ModuleNotFoundError(path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers( mod, predicate=lambda member: inspect.isfunction(member) and member.__name__.startswith("test_"), ) for _, test in tests: if not (marks := getattr(test, "pytestmark", None)): yield (f"{path.relative_to(root).as_posix()}::{test.__name__}") continue for mark in cast(Iterable[pytest.Mark], marks): if mark.name == "parametrize": ids = mark.kwargs.get("ids") argnames = mark.args[0] argvalues = mark.args[1] argnames, parametersets = ParameterSet._for_parametrize( # pyright: ignore[reportUnknownMemberType, reportPrivateUsage] argnames, argvalues, None, None, # pyright: ignore[reportArgumentType] None, # pyright: ignore[reportArgumentType] ) if ids is None: idfn = None ids_ = None elif callable(ids): idfn = ids ids_ = None else: idfn = None ids_ = pytest.Metafunc._validate_ids( # pyright: ignore[reportPrivateUsage] None, # pyright: ignore[reportArgumentType] ids, parametersets, test.__name__, # pyright: ignore[reportArgumentType] ) id_maker = IdMaker( argnames, parametersets, idfn, ids_, None, nodeid=None, func_name=test.__name__, ) yield from ( f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]" for id_ in id_maker.make_unique_parameterset_ids() ) def _close_socket(path: Path) -> None: if path.exists(): os.unlink(path) exit(0) async def serve_socket(): if not SOCKET_ROOT_DIR.exists(): SOCKET_ROOT_DIR.mkdir() python_socket_path = ( SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex() ) if python_socket_path.exists(): print(python_socket_path) atexit.register(lambda: _close_socket(python_socket_path)) signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path)) signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path)) # signal.signal(signal.SIGKILL, lambda x, _: _close_socket(python_socket_path)) async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): path = await reader.readline() tests = "\n".join( [test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast] ) writer.write(f"{tests}\n".encode()) writer.close() server = await asyncio.start_unix_server( handle_client, python_socket_path.absolute().as_posix() ) async with server: print(python_socket_path) sys.stdout.close() await server.serve_forever() def main( paths: Iterable[str], collect_only: bool = True, quiet: bool = True, verbosity: int = 0, socket_mode: bool = False, no_fork: bool = True, ): tests = list(get_tests(paths)) if tests: print("\n".join(tests)) if not socket_mode: return asyncio.run(serve_socket()) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--collect-only", dest="collect_only", action="store_true") parser.add_argument("-q", dest="quiet", action="store_true") parser.add_argument("--verbosity", dest="verbosity", default=0) parser.add_argument("-s", dest="socket_mode", action="store_true") parser.add_argument("-f", dest="no_fork", action="store_true") parser.add_argument( "paths", nargs="*", ) args = parser.parse_args() paths: list[str] = args.paths main( paths=paths, quiet=args.quiet, collect_only=args.collect_only, verbosity=args.verbosity, socket_mode=args.socket_mode, no_fork=args.no_fork, )