import asyncio import atexit from collections import deque from collections.abc import Iterable import hashlib import itertools import logging import os import select import signal import socket import sys import time import pytest import importlib import importlib.util import argparse import inspect from pathlib import Path from typing import Any, Callable, Generator, Self, cast SOCKET_ROOT_DIR = Path("/tmp/neotest-python") class LineReceiver: def __init__(self, s: socket.socket) -> None: self._s = s self._s.setblocking(False) self._tmp_buffer = "" self._tmp_lines: deque[str] = deque() def __iter__(self) -> Self: return self def __next__(self) -> str: if self._tmp_lines: return self._tmp_lines.popleft() ready, _, _ = select.select([self._s], [], [], 0.1) if ready: msg = self._s.recv(256).decode() else: msg = "" lines = msg.splitlines() if not lines: raise StopIteration lines[0] = self._tmp_buffer + lines[0] if not msg.endswith("\n"): self._tmp_buffer = lines[-1] else: self._tmp_buffer = "" self._tmp_lines.extend(lines) return self._tmp_lines.popleft() def get_tests(paths: Iterable[str]) -> Generator[str, None, None]: root = Path(os.curdir).absolute() if Path(os.curdir).absolute().as_posix() not in sys.path: sys.path.insert(0, Path(os.curdir).absolute().as_posix()) print(sys.path, file=sys.stderr) for path in (Path(path) for path in paths): spec = importlib.util.spec_from_file_location(path.name, path) if spec is None or spec.loader is None: raise ModuleNotFoundError(path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers( mod, predicate=lambda member: inspect.isfunction(member) and member.__name__.startswith("test_"), ) for _, test in tests: if not (marks := getattr(test, "pytestmark", None)): yield (f"{path.relative_to(root).as_posix()}::{test.__name__}") continue for mark in cast(Iterable[pytest.Mark], marks): if mark.name == "parametrize": ids = mark.kwargs.get("ids", []) for id_, params in itertools.zip_longest(ids, mark.args[1]): id_ = ( getattr(params, "id", None) or id_ or "-".join(repr(param) for param in params) ) yield ( f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]" ) def _close_socket(path: Path) -> None: if path.exists(): os.unlink(path) exit(0) async def serve_socket(): if not SOCKET_ROOT_DIR.exists(): SOCKET_ROOT_DIR.mkdir() python_socket_path = ( SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex() ) if python_socket_path.exists(): print(python_socket_path) return atexit.register(lambda: _close_socket(python_socket_path)) signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path)) signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path)) # signal.signal(signal.SIGKILL, lambda x, _: _close_socket(python_socket_path)) async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): path = await reader.readline() tests = "\n".join( [test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast] ) writer.write(f"{tests}\n".encode()) writer.close() server = await asyncio.start_unix_server( handle_client, python_socket_path.absolute().as_posix() ) async with server: print(python_socket_path) sys.stdout.close() await server.serve_forever() def main( paths: Iterable[str], collect_only: bool = True, quiet: bool = True, verbosity: int = 0, socket_mode: bool = False, no_fork: bool = True, ): tests = list(get_tests(paths)) if tests: print("\n".join(tests)) if not socket_mode: return asyncio.run(serve_socket()) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--collect-only", dest="collect_only", action="store_true") parser.add_argument("-q", dest="quiet", action="store_true") parser.add_argument("--verbosity", dest="verbosity", default=0) parser.add_argument("-s", dest="socket_mode", action="store_true") parser.add_argument("-f", dest="no_fork", action="store_true") parser.add_argument( "paths", nargs="*", ) args = parser.parse_args() paths: list[str] = args.paths main( paths=paths, quiet=args.quiet, collect_only=args.collect_only, verbosity=args.verbosity, socket_mode=args.socket_mode, no_fork=args.no_fork, )