feat(pytest): use socket instead of a shitton of processes
This commit is contained in:
200
neotest_python/params_getter.py
Normal file
200
neotest_python/params_getter.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
import hashlib
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
from _pytest.mark.structures import ParameterSet
|
||||
from _pytest.python import IdMaker
|
||||
import pytest
|
||||
import importlib
|
||||
import importlib.util
|
||||
import argparse
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generator, Self, cast
|
||||
|
||||
SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
|
||||
|
||||
|
||||
class LineReceiver:
|
||||
def __init__(self, s: socket.socket) -> None:
|
||||
self._s = s
|
||||
self._s.setblocking(False)
|
||||
self._tmp_buffer = ""
|
||||
self._tmp_lines: deque[str] = deque()
|
||||
|
||||
def __iter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __next__(self) -> str:
|
||||
if self._tmp_lines:
|
||||
return self._tmp_lines.popleft()
|
||||
|
||||
ready, _, _ = select.select([self._s], [], [], 0.1)
|
||||
if ready:
|
||||
msg = self._s.recv(256).decode()
|
||||
else:
|
||||
msg = ""
|
||||
lines = msg.splitlines()
|
||||
|
||||
if not lines:
|
||||
raise StopIteration
|
||||
|
||||
lines[0] = self._tmp_buffer + lines[0]
|
||||
if not msg.endswith("\n"):
|
||||
self._tmp_buffer = lines[-1]
|
||||
else:
|
||||
self._tmp_buffer = ""
|
||||
self._tmp_lines.extend(lines)
|
||||
|
||||
return self._tmp_lines.popleft()
|
||||
|
||||
|
||||
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
|
||||
root = Path(os.curdir).absolute()
|
||||
if Path(os.curdir).absolute().as_posix() not in sys.path:
|
||||
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
|
||||
|
||||
for path in (Path(path) for path in paths):
|
||||
spec = importlib.util.spec_from_file_location(path.name, path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ModuleNotFoundError(path)
|
||||
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers(
|
||||
mod,
|
||||
predicate=lambda member: inspect.isfunction(member)
|
||||
and member.__name__.startswith("test_"),
|
||||
)
|
||||
for _, test in tests:
|
||||
if not (marks := getattr(test, "pytestmark", None)):
|
||||
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}")
|
||||
continue
|
||||
|
||||
for mark in cast(Iterable[pytest.Mark], marks):
|
||||
if mark.name == "parametrize":
|
||||
ids = mark.kwargs.get("ids")
|
||||
argnames = mark.args[0]
|
||||
argvalues = mark.args[1]
|
||||
argnames, parametersets = ParameterSet._for_parametrize( # pyright: ignore[reportUnknownMemberType, reportPrivateUsage]
|
||||
argnames,
|
||||
argvalues,
|
||||
None,
|
||||
None, # pyright: ignore[reportArgumentType]
|
||||
None, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
if ids is None:
|
||||
idfn = None
|
||||
ids_ = None
|
||||
elif callable(ids):
|
||||
idfn = ids
|
||||
ids_ = None
|
||||
else:
|
||||
idfn = None
|
||||
ids_ = pytest.Metafunc._validate_ids( # pyright: ignore[reportPrivateUsage]
|
||||
None, # pyright: ignore[reportArgumentType]
|
||||
ids,
|
||||
parametersets,
|
||||
test.__name__, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
|
||||
id_maker = IdMaker(
|
||||
argnames,
|
||||
parametersets,
|
||||
idfn,
|
||||
ids_,
|
||||
None,
|
||||
nodeid=None,
|
||||
func_name=test.__name__,
|
||||
)
|
||||
yield from (
|
||||
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
|
||||
for id_ in id_maker.make_unique_parameterset_ids()
|
||||
)
|
||||
|
||||
|
||||
def _close_socket(path: Path) -> None:
|
||||
if path.exists():
|
||||
os.unlink(path)
|
||||
exit(0)
|
||||
|
||||
|
||||
async def serve_socket():
|
||||
if not SOCKET_ROOT_DIR.exists():
|
||||
SOCKET_ROOT_DIR.mkdir()
|
||||
|
||||
python_socket_path = (
|
||||
SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex()
|
||||
)
|
||||
|
||||
if python_socket_path.exists():
|
||||
print(python_socket_path)
|
||||
return
|
||||
atexit.register(lambda: _close_socket(python_socket_path))
|
||||
|
||||
signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path))
|
||||
signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path))
|
||||
# signal.signal(signal.SIGKILL, lambda x, _: _close_socket(python_socket_path))
|
||||
|
||||
async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
path = await reader.readline()
|
||||
tests = "\n".join(
|
||||
[test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast]
|
||||
)
|
||||
writer.write(f"{tests}\n".encode())
|
||||
writer.close()
|
||||
|
||||
server = await asyncio.start_unix_server(
|
||||
handle_client, python_socket_path.absolute().as_posix()
|
||||
)
|
||||
|
||||
async with server:
|
||||
print(python_socket_path)
|
||||
sys.stdout.close()
|
||||
await server.serve_forever()
|
||||
|
||||
|
||||
def main(
|
||||
paths: Iterable[str],
|
||||
collect_only: bool = True,
|
||||
quiet: bool = True,
|
||||
verbosity: int = 0,
|
||||
socket_mode: bool = False,
|
||||
no_fork: bool = True,
|
||||
):
|
||||
tests = list(get_tests(paths))
|
||||
if tests:
|
||||
print("\n".join(tests))
|
||||
if not socket_mode:
|
||||
return
|
||||
asyncio.run(serve_socket())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--collect-only", dest="collect_only", action="store_true")
|
||||
parser.add_argument("-q", dest="quiet", action="store_true")
|
||||
parser.add_argument("--verbosity", dest="verbosity", default=0)
|
||||
parser.add_argument("-s", dest="socket_mode", action="store_true")
|
||||
parser.add_argument("-f", dest="no_fork", action="store_true")
|
||||
parser.add_argument(
|
||||
"paths",
|
||||
nargs="*",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
paths: list[str] = args.paths
|
||||
main(
|
||||
paths=paths,
|
||||
quiet=args.quiet,
|
||||
collect_only=args.collect_only,
|
||||
verbosity=args.verbosity,
|
||||
socket_mode=args.socket_mode,
|
||||
no_fork=args.no_fork,
|
||||
)
|
Reference in New Issue
Block a user