175 lines
5.2 KiB
Python
175 lines
5.2 KiB
Python
import asyncio
|
|
import atexit
|
|
from collections import deque
|
|
from collections.abc import Iterable
|
|
import hashlib
|
|
import itertools
|
|
import logging
|
|
import os
|
|
import select
|
|
import signal
|
|
import socket
|
|
import sys
|
|
import time
|
|
import pytest
|
|
import importlib
|
|
import importlib.util
|
|
import argparse
|
|
import inspect
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Generator, Self, cast
|
|
|
|
# Directory holding the discovery sockets; serve_socket() names each socket
# after a hash of the interpreter path, so every Python executable gets its own.
SOCKET_ROOT_DIR = Path("/tmp") / "neotest-python"
|
|
|
|
|
|
class LineReceiver:
|
|
def __init__(self, s: socket.socket) -> None:
|
|
self._s = s
|
|
self._s.setblocking(False)
|
|
self._tmp_buffer = ""
|
|
self._tmp_lines: deque[str] = deque()
|
|
|
|
def __iter__(self) -> Self:
|
|
return self
|
|
|
|
def __next__(self) -> str:
|
|
if self._tmp_lines:
|
|
return self._tmp_lines.popleft()
|
|
|
|
ready, _, _ = select.select([self._s], [], [], 0.1)
|
|
if ready:
|
|
msg = self._s.recv(256).decode()
|
|
else:
|
|
msg = ""
|
|
lines = msg.splitlines()
|
|
|
|
if not lines:
|
|
raise StopIteration
|
|
|
|
lines[0] = self._tmp_buffer + lines[0]
|
|
if not msg.endswith("\n"):
|
|
self._tmp_buffer = lines[-1]
|
|
else:
|
|
self._tmp_buffer = ""
|
|
self._tmp_lines.extend(lines)
|
|
|
|
return self._tmp_lines.popleft()
|
|
|
|
|
|
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
|
|
root = Path(os.curdir).absolute()
|
|
if Path(os.curdir).absolute().as_posix() not in sys.path:
|
|
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
|
|
print(sys.path, file=sys.stderr)
|
|
|
|
for path in (Path(path) for path in paths):
|
|
spec = importlib.util.spec_from_file_location(path.name, path)
|
|
if spec is None or spec.loader is None:
|
|
raise ModuleNotFoundError(path)
|
|
|
|
mod = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(mod)
|
|
|
|
tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers(
|
|
mod,
|
|
predicate=lambda member: inspect.isfunction(member)
|
|
and member.__name__.startswith("test_"),
|
|
)
|
|
for _, test in tests:
|
|
if not (marks := getattr(test, "pytestmark", None)):
|
|
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}")
|
|
continue
|
|
|
|
for mark in cast(Iterable[pytest.Mark], marks):
|
|
if mark.name == "parametrize":
|
|
ids = mark.kwargs.get("ids", [])
|
|
for id_, params in itertools.zip_longest(ids, mark.args[1]):
|
|
id_ = (
|
|
getattr(params, "id", None)
|
|
or id_
|
|
or "-".join(repr(param) for param in params)
|
|
)
|
|
yield (
|
|
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
|
|
)
|
|
|
|
|
|
def _close_socket(path: Path) -> None:
|
|
if path.exists():
|
|
os.unlink(path)
|
|
exit(0)
|
|
|
|
|
|
async def serve_socket() -> None:
    """Serve test discovery over a per-interpreter Unix socket.

    The socket lives in ``SOCKET_ROOT_DIR`` and is named after the SHA-1 hex
    digest of ``sys.executable``. Its path is printed to stdout. If a server is
    already accepting connections on that path, the path is printed and this
    coroutine returns. A socket file nobody accepts on is treated as stale and
    replaced.

    Each client sends one line holding a test-file path and receives the node
    IDs found by ``get_tests`` (newline-separated, newline-terminated), after
    which the connection is closed. An empty request, as sent by the
    liveness probe above, gets no reply.
    """
    SOCKET_ROOT_DIR.mkdir(parents=True, exist_ok=True)
    python_socket_path = (
        SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).hexdigest()
    )

    if python_socket_path.exists():
        try:
            _, probe = await asyncio.open_unix_connection(
                python_socket_path.as_posix()
            )
        except OSError:
            # Nobody accepts connections: the file was left by a server that
            # died without cleanup (SIGKILL cannot be caught), so reclaim it.
            python_socket_path.unlink(missing_ok=True)
        else:
            probe.close()
            print(python_socket_path)
            return

    atexit.register(lambda: _close_socket(python_socket_path))
    signal.signal(
        signal.SIGTERM, lambda _signum, _frame: _close_socket(python_socket_path)
    )
    signal.signal(
        signal.SIGINT, lambda _signum, _frame: _close_socket(python_socket_path)
    )

    async def handle_client(
        reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        try:
            request = (await reader.readline()).decode().strip()
            if request:
                tests = "\n".join(get_tests([request]))
                writer.write(f"{tests}\n".encode())
                await writer.drain()
        except Exception:
            # Keep serving other clients; record why this request failed.
            logging.exception("neotest-python: test discovery failed")
        finally:
            # Always close so the client is never left waiting on an open connection.
            writer.close()

    server = await asyncio.start_unix_server(
        handle_client, python_socket_path.absolute().as_posix()
    )

    async with server:
        print(python_socket_path)
        sys.stdout.close()
        # Test modules imported by later requests may print at import time;
        # give them a sink instead of a closed stream (which would raise).
        sys.stdout = open(os.devnull, "w", encoding="utf-8")
        await server.serve_forever()
|
|
|
|
|
|
def main(
    paths: Iterable[str],
    collect_only: bool = True,
    quiet: bool = True,
    verbosity: int = 0,
    socket_mode: bool = False,
    no_fork: bool = True,
):
    """Print the node IDs found in *paths*, one per line, then optionally serve.

    Nothing is printed when no tests are found. When *socket_mode* is set, the
    function then runs ``serve_socket()`` in a fresh event loop. The
    *collect_only*, *quiet*, *verbosity* and *no_fork* arguments are accepted
    but not used here.
    """
    discovered = list(get_tests(paths))
    if discovered:
        print(*discovered, sep="\n")
    if socket_mode:
        asyncio.run(serve_socket())
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the flags and forward them to main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--collect-only", dest="collect_only", action="store_true")
    parser.add_argument("-q", dest="quiet", action="store_true")
    # type=int: main() declares verbosity as int; without it a supplied value
    # would arrive as a str.
    parser.add_argument("--verbosity", dest="verbosity", type=int, default=0)
    parser.add_argument(
        "-s",
        dest="socket_mode",
        action="store_true",
        help="after printing, keep serving discovery requests on a Unix socket",
    )
    parser.add_argument("-f", dest="no_fork", action="store_true")
    parser.add_argument(
        "paths",
        nargs="*",
        help="test files to collect",
    )
    args = parser.parse_args()
    paths: list[str] = args.paths
    main(
        paths=paths,
        quiet=args.quiet,
        collect_only=args.collect_only,
        verbosity=args.verbosity,
        socket_mode=args.socket_mode,
        no_fork=args.no_fork,
    )
|