201 lines
6.3 KiB
Python
201 lines
6.3 KiB
Python
import asyncio
|
|
import atexit
|
|
from collections import deque
|
|
from collections.abc import Iterable
|
|
import hashlib
|
|
import os
|
|
import select
|
|
import signal
|
|
import socket
|
|
import sys
|
|
from _pytest.mark.structures import ParameterSet
|
|
from _pytest.python import IdMaker
|
|
import pytest
|
|
import importlib
|
|
import importlib.util
|
|
import argparse
|
|
import inspect
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Generator, Self, cast
|
|
|
|
# Directory in which serve_socket() creates its Unix domain socket; the socket
# file itself is named after the SHA-1 hex digest of sys.executable.
SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
|
|
|
|
|
|
class LineReceiver:
|
|
def __init__(self, s: socket.socket) -> None:
|
|
self._s = s
|
|
self._s.setblocking(False)
|
|
self._tmp_buffer = ""
|
|
self._tmp_lines: deque[str] = deque()
|
|
|
|
def __iter__(self) -> Self:
|
|
return self
|
|
|
|
def __next__(self) -> str:
|
|
if self._tmp_lines:
|
|
return self._tmp_lines.popleft()
|
|
|
|
ready, _, _ = select.select([self._s], [], [], 0.1)
|
|
if ready:
|
|
msg = self._s.recv(256).decode()
|
|
else:
|
|
msg = ""
|
|
lines = msg.splitlines()
|
|
|
|
if not lines:
|
|
raise StopIteration
|
|
|
|
lines[0] = self._tmp_buffer + lines[0]
|
|
if not msg.endswith("\n"):
|
|
self._tmp_buffer = lines[-1]
|
|
else:
|
|
self._tmp_buffer = ""
|
|
self._tmp_lines.extend(lines)
|
|
|
|
return self._tmp_lines.popleft()
|
|
|
|
|
|
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
|
|
root = Path(os.curdir).absolute()
|
|
if Path(os.curdir).absolute().as_posix() not in sys.path:
|
|
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
|
|
|
|
for path in (Path(path) for path in paths):
|
|
spec = importlib.util.spec_from_file_location(path.name, path)
|
|
if spec is None or spec.loader is None:
|
|
raise ModuleNotFoundError(path)
|
|
|
|
mod = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(mod)
|
|
|
|
tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers(
|
|
mod,
|
|
predicate=lambda member: inspect.isfunction(member)
|
|
and member.__name__.startswith("test_"),
|
|
)
|
|
for _, test in tests:
|
|
if not (marks := getattr(test, "pytestmark", None)):
|
|
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}")
|
|
continue
|
|
|
|
for mark in cast(Iterable[pytest.Mark], marks):
|
|
if mark.name == "parametrize":
|
|
ids = mark.kwargs.get("ids")
|
|
argnames = mark.args[0]
|
|
argvalues = mark.args[1]
|
|
argnames, parametersets = ParameterSet._for_parametrize( # pyright: ignore[reportUnknownMemberType, reportPrivateUsage]
|
|
argnames,
|
|
argvalues,
|
|
None,
|
|
None, # pyright: ignore[reportArgumentType]
|
|
None, # pyright: ignore[reportArgumentType]
|
|
)
|
|
if ids is None:
|
|
idfn = None
|
|
ids_ = None
|
|
elif callable(ids):
|
|
idfn = ids
|
|
ids_ = None
|
|
else:
|
|
idfn = None
|
|
ids_ = pytest.Metafunc._validate_ids( # pyright: ignore[reportPrivateUsage]
|
|
None, # pyright: ignore[reportArgumentType]
|
|
ids,
|
|
parametersets,
|
|
test.__name__, # pyright: ignore[reportArgumentType]
|
|
)
|
|
|
|
id_maker = IdMaker(
|
|
argnames,
|
|
parametersets,
|
|
idfn,
|
|
ids_,
|
|
None,
|
|
nodeid=None,
|
|
func_name=test.__name__,
|
|
)
|
|
yield from (
|
|
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
|
|
for id_ in id_maker.make_unique_parameterset_ids()
|
|
)
|
|
|
|
|
|
def _close_socket(path: Path) -> None:
|
|
if path.exists():
|
|
os.unlink(path)
|
|
exit(0)
|
|
|
|
|
|
async def serve_socket():
    """Serve test-collection requests on a Unix domain socket until stopped.

    The socket lives in ``SOCKET_ROOT_DIR`` and is named after the SHA-1 hex
    digest of ``sys.executable``. Its path is printed to stdout, after which
    stdout is closed. Each client sends one line containing a test file path
    and receives the collected node ids, one per line, before the connection
    is closed.

    ``_close_socket`` is registered for interpreter exit, SIGTERM and SIGINT
    so the socket file is removed when the process stops.
    """
    # exist_ok avoids failing when another process creates the directory
    # between a check and the mkdir call.
    SOCKET_ROOT_DIR.mkdir(parents=True, exist_ok=True)

    python_socket_path = (
        SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).hexdigest()
    )

    if python_socket_path.exists():
        # NOTE(review): the path is announced here and execution continues,
        # so the server below is still started on the same path and the path
        # is printed a second time. Confirm whether an early return was
        # intended when another server already owns this socket.
        print(python_socket_path)

    atexit.register(lambda: _close_socket(python_socket_path))

    # SIGKILL cannot be caught, so only SIGTERM and SIGINT are handled.
    signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path))
    signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path))

    async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Answer one request: read a path line, reply with its test ids."""
        try:
            requested = (await reader.readline()).decode().strip()
            if not requested:
                # Client disconnected or sent an empty line: nothing to do.
                return
            tests = "\n".join(get_tests([requested]))
            writer.write(f"{tests}\n".encode())
            # Make sure the reply has left the buffer before closing.
            await writer.drain()
        finally:
            # Always release the connection, even when collection raised, so
            # the client sees EOF instead of waiting forever.
            writer.close()

    server = await asyncio.start_unix_server(
        handle_client, python_socket_path.absolute().as_posix()
    )

    async with server:
        print(python_socket_path)
        sys.stdout.close()
        await server.serve_forever()
|
|
|
|
|
|
def main(
    paths: Iterable[str],
    collect_only: bool = True,
    quiet: bool = True,
    verbosity: int = 0,
    socket_mode: bool = False,
    no_fork: bool = True,
):
    """Print the test ids collected from ``paths`` and optionally start the server.

    The ids produced by ``get_tests`` are printed to stdout, one per line;
    nothing is printed when there are none. With ``socket_mode`` set, the
    socket server is then run via ``asyncio.run(serve_socket())``.

    ``collect_only``, ``quiet``, ``verbosity`` and ``no_fork`` are accepted
    but not used by this function.
    """
    collected = [*get_tests(paths)]
    if collected:
        print("\n".join(collected))

    if socket_mode:
        asyncio.run(serve_socket())
|
|
|
|
|
|
# Command-line entry point: parse the arguments and forward them to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--collect-only", dest="collect_only", action="store_true")
    parser.add_argument("-q", dest="quiet", action="store_true")
    # type=int keeps the value an int when given on the command line, which
    # matches the default and main()'s ``verbosity: int`` parameter.
    parser.add_argument("--verbosity", dest="verbosity", type=int, default=0)
    parser.add_argument("-s", dest="socket_mode", action="store_true")
    parser.add_argument("-f", dest="no_fork", action="store_true")
    parser.add_argument(
        "paths",
        nargs="*",
    )
    args = parser.parse_args()
    paths: list[str] = args.paths
    main(
        paths=paths,
        quiet=args.quiet,
        collect_only=args.collect_only,
        verbosity=args.verbosity,
        socket_mode=args.socket_mode,
        no_fork=args.no_fork,
    )
|