Compare commits
1 Commits
master
...
ff75e11bca
Author | SHA1 | Date | |
---|---|---|---|
|
ff75e11bca |
@@ -65,13 +65,9 @@ return function(config)
|
||||
local python_command = config.get_python_command(root)
|
||||
local runner = config.get_runner(python_command)
|
||||
|
||||
local positions = lib.treesitter.parse_positions(
|
||||
path,
|
||||
base.treesitter_queries(runner, config, python_command),
|
||||
{
|
||||
require_namespaces = runner == "unittest",
|
||||
}
|
||||
)
|
||||
local positions = lib.treesitter.parse_positions(path, base.treesitter_queries(runner, config, python_command), {
|
||||
require_namespaces = runner == "unittest",
|
||||
})
|
||||
|
||||
if runner == "pytest" and config.pytest_discovery then
|
||||
pytest.augment_positions(python_command, base.get_script_path(), path, positions, root)
|
||||
|
@@ -96,7 +96,7 @@ end
|
||||
function M.get_script_path()
|
||||
local paths = vim.api.nvim_get_runtime_file("neotest.py", true)
|
||||
for _, path in ipairs(paths) do
|
||||
if vim.endswith(path, ("python%sneotest.py"):format(lib.files.sep)) then
|
||||
if vim.endswith(path, ("neotest-python%sneotest.py"):format(lib.files.sep)) then
|
||||
return path
|
||||
end
|
||||
end
|
||||
|
@@ -1,9 +1,54 @@
|
||||
local lib = require("neotest.lib")
|
||||
local nio = require("nio")
|
||||
local logger = require("neotest.logging")
|
||||
|
||||
local M = {}
|
||||
|
||||
local unix = require("socket.unix")
|
||||
local socket_path = ""
|
||||
|
||||
--- Run a command, connect to the UNIX socket it prints, send messages, return response
|
||||
-- @param cmd string: Command to run (should output socket path)
|
||||
-- @param messages table|string: One or more messages to send
|
||||
-- @return string: Concatenated response from server
|
||||
local function talk_unix(cmd, messages)
|
||||
-- 1. Run the command and capture its output (socket path)
|
||||
if not socket_path then
|
||||
local handle = assert(io.popen(cmd, "r"))
|
||||
socket_path = handle:read("*l")
|
||||
handle:close()
|
||||
assert(socket_path, "Command did not return a socket path")
|
||||
end
|
||||
|
||||
-- 2. Connect to the unix socket
|
||||
local client = assert(unix())
|
||||
assert(client:connect(socket_path))
|
||||
|
||||
-- 3. Send message(s)
|
||||
if type(messages) == "string" then
|
||||
messages = { messages }
|
||||
end
|
||||
for _, msg in ipairs(messages) do
|
||||
assert(client:send(msg .. "\n"))
|
||||
end
|
||||
|
||||
-- 4. Read response until EOF or timeout
|
||||
client:settimeout(1)
|
||||
local chunks = {}
|
||||
while true do
|
||||
local data, err = client:receive("*l")
|
||||
if not data then
|
||||
if err ~= "timeout" and err ~= "closed" then
|
||||
error("Socket receive error: " .. tostring(err))
|
||||
end
|
||||
break
|
||||
end
|
||||
table.insert(chunks, data)
|
||||
end
|
||||
|
||||
client:close()
|
||||
return table.concat(chunks, "\n")
|
||||
end
|
||||
|
||||
---@async
|
||||
---Add test instances for path in root to positions
|
||||
---@param positions neotest.Tree
|
||||
@@ -16,7 +61,7 @@ local function add_test_instances(positions, test_params)
|
||||
for _, params_str in ipairs(pos_params) do
|
||||
local new_data = vim.tbl_extend("force", position, {
|
||||
id = string.format("%s[%s]", position.id, params_str),
|
||||
name = string.format("%s", params_str),
|
||||
name = string.format("%s[%s]", position.name, params_str),
|
||||
})
|
||||
new_data.range = nil
|
||||
|
||||
@@ -27,74 +72,6 @@ local function add_test_instances(positions, test_params)
|
||||
end
|
||||
end
|
||||
|
||||
local socket_path = ""
|
||||
local socket_future = nio.control.future()
|
||||
local socket_start = nio.control.future()
|
||||
--- Run a command, connect to the UNIX socket it prints, send messages, return response
|
||||
-- @param cmd string: Command to run (should output socket path)
|
||||
-- @param messages table|string: One or more messages to send
|
||||
-- @param callback uv.read_start.callback: One or more messages to send
|
||||
-- @return string: Concatenated response from server
|
||||
local function get_socket_path(cmd, messages, callback)
|
||||
-- 1. Run the command and capture its output (socket path)
|
||||
if socket_path == "" and not socket_start.is_set() then
|
||||
local stdout = assert(vim.uv.new_pipe(true))
|
||||
local stderr = assert(vim.uv.new_pipe(true))
|
||||
local stdin = assert(vim.uv.new_pipe(true))
|
||||
socket_start.set()
|
||||
local handle
|
||||
handle, _ = vim.uv.spawn(cmd[1], {
|
||||
stdio = { stdin, stdout, stderr },
|
||||
detached = true,
|
||||
args = #cmd > 1 and vim.list_slice(cmd, 2, #cmd) or { cmd[1] },
|
||||
}, function(code, signal)
|
||||
vim.uv.close(stdout)
|
||||
vim.uv.close(stderr)
|
||||
vim.uv.close(stdin)
|
||||
vim.uv.close(handle)
|
||||
end)
|
||||
|
||||
vim.uv.read_start(stdout, function(err, s)
|
||||
if s ~= nil then
|
||||
socket_path = string.gsub(s, "\n$", "")
|
||||
socket_future.set()
|
||||
end
|
||||
end)
|
||||
|
||||
vim.uv.read_start(stderr, function(err, s)
|
||||
if err ~= nil then
|
||||
vim.print(err, s)
|
||||
end
|
||||
end)
|
||||
end
|
||||
socket_future.wait()
|
||||
|
||||
-- 2. Connect to the unix socket
|
||||
local client = assert(vim.uv.new_pipe(false))
|
||||
|
||||
client:connect(socket_path, function(err)
|
||||
if err ~= nil then
|
||||
vim.print("Error", err)
|
||||
end
|
||||
-- 3. Send message(s)
|
||||
if type(messages) == "string" then
|
||||
messages = { messages }
|
||||
end
|
||||
for _, msg in ipairs(messages) do
|
||||
vim.uv.write(client, msg .. "\n", function(err_write)
|
||||
client:read_start(function(err_read, data)
|
||||
if err_read ~= nil then
|
||||
vim.print(err_read, data)
|
||||
else
|
||||
callback(err_read, data)
|
||||
end
|
||||
client:close()
|
||||
end)
|
||||
end)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
---@async
|
||||
---@param path string
|
||||
---@return boolean
|
||||
@@ -122,36 +99,28 @@ end
|
||||
---@param positions neotest.Tree
|
||||
---@param root string
|
||||
local function discover_params(python, script, path, positions, root)
|
||||
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", path })
|
||||
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", "-s", path })
|
||||
logger.debug("Running test instance discovery:", cmd)
|
||||
|
||||
local test_params = {}
|
||||
get_socket_path(cmd, path, function(err, data)
|
||||
if err ~= nil then
|
||||
vim.print(err, data)
|
||||
return
|
||||
end
|
||||
if data == nil then
|
||||
return
|
||||
end
|
||||
local data = talk_unix(cmd, path)
|
||||
|
||||
for line in vim.gsplit(data, "\n", { trimempty = true, plain = true }) do
|
||||
local param_index = string.find(line, "[", nil, true)
|
||||
if param_index then
|
||||
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
|
||||
local param_id = string.sub(line, param_index + 1, #line - 1)
|
||||
for line in vim.gsplit(data, "\n", true) do
|
||||
local param_index = string.find(line, "[", nil, true)
|
||||
if param_index then
|
||||
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)
|
||||
local param_id = string.sub(line, param_index + 1, #line - 1)
|
||||
|
||||
if positions:get_key(test_id) then
|
||||
if not test_params[test_id] then
|
||||
test_params[test_id] = { param_id }
|
||||
else
|
||||
table.insert(test_params[test_id], param_id)
|
||||
end
|
||||
if positions:get_key(test_id) then
|
||||
if not test_params[test_id] then
|
||||
test_params[test_id] = { param_id }
|
||||
else
|
||||
table.insert(test_params[test_id], param_id)
|
||||
end
|
||||
end
|
||||
end
|
||||
return add_test_instances(positions, test_params)
|
||||
end)
|
||||
end
|
||||
return test_params
|
||||
end
|
||||
|
||||
---@async
|
||||
@@ -163,7 +132,8 @@ end
|
||||
---@param root string
|
||||
function M.augment_positions(python, script, path, positions, root)
|
||||
if has_parametrize(path) then
|
||||
discover_params(python, script, path, positions, root)
|
||||
local test_params = discover_params(python, script, path, positions, root)
|
||||
add_test_instances(positions, test_params)
|
||||
end
|
||||
end
|
||||
|
||||
|
@@ -14,7 +14,7 @@ class TestRunner(str, Enum):
|
||||
|
||||
def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter:
|
||||
if runner == TestRunner.PYTEST:
|
||||
from .pytest_ import PytestNeotestAdapter
|
||||
from .pytest import PytestNeotestAdapter
|
||||
|
||||
return PytestNeotestAdapter(emit_parameterized_ids)
|
||||
elif runner == TestRunner.UNITTEST:
|
||||
@@ -53,14 +53,14 @@ parser.add_argument("args", nargs="*")
|
||||
def main(argv: List[str]):
|
||||
if "--pytest-collect" in argv:
|
||||
argv.remove("--pytest-collect")
|
||||
from .pytest_ import collect
|
||||
from .pytest import collect
|
||||
|
||||
collect(argv)
|
||||
return
|
||||
|
||||
if "--pytest-extract-test-name-template" in argv:
|
||||
argv.remove("--pytest-extract-test-name-template")
|
||||
from .pytest_ import extract_test_name_template
|
||||
from .pytest import extract_test_name_template
|
||||
|
||||
extract_test_name_template(argv)
|
||||
return
|
||||
|
@@ -1,69 +1,30 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
from _pytest.mark.structures import ParameterSet
|
||||
from _pytest.python import IdMaker
|
||||
import pytest
|
||||
import importlib
|
||||
import importlib.util
|
||||
import argparse
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generator, Self, cast
|
||||
from typing import Any, Callable, cast
|
||||
|
||||
SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
|
||||
|
||||
|
||||
class LineReceiver:
|
||||
def __init__(self, s: socket.socket) -> None:
|
||||
self._s = s
|
||||
self._s.setblocking(False)
|
||||
self._tmp_buffer = ""
|
||||
self._tmp_lines: deque[str] = deque()
|
||||
|
||||
def __iter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __next__(self) -> str:
|
||||
if self._tmp_lines:
|
||||
return self._tmp_lines.popleft()
|
||||
|
||||
ready, _, _ = select.select([self._s], [], [], 0.1)
|
||||
if ready:
|
||||
msg = self._s.recv(256).decode()
|
||||
else:
|
||||
msg = ""
|
||||
lines = msg.splitlines()
|
||||
|
||||
if not lines:
|
||||
raise StopIteration
|
||||
|
||||
lines[0] = self._tmp_buffer + lines[0]
|
||||
if not msg.endswith("\n"):
|
||||
self._tmp_buffer = lines[-1]
|
||||
else:
|
||||
self._tmp_buffer = ""
|
||||
self._tmp_lines.extend(lines)
|
||||
|
||||
return self._tmp_lines.popleft()
|
||||
|
||||
|
||||
def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
|
||||
root = Path(os.curdir).absolute()
|
||||
if Path(os.curdir).absolute().as_posix() not in sys.path:
|
||||
sys.path.insert(0, Path(os.curdir).absolute().as_posix())
|
||||
def get_tests(paths: Iterable[str]) -> list[str]:
|
||||
test_names: list[str] = []
|
||||
|
||||
for path in (Path(path) for path in paths):
|
||||
spec = importlib.util.spec_from_file_location(path.name, path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ModuleNotFoundError(path)
|
||||
raise ModuleNotFoundError
|
||||
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
@@ -75,58 +36,32 @@ def get_tests(paths: Iterable[str]) -> Generator[str, None, None]:
|
||||
)
|
||||
for _, test in tests:
|
||||
if not (marks := getattr(test, "pytestmark", None)):
|
||||
yield (f"{path.relative_to(root).as_posix()}::{test.__name__}")
|
||||
test_names.append(f"{test.__module__}::{test.__name__}")
|
||||
continue
|
||||
|
||||
for mark in cast(Iterable[pytest.Mark], marks):
|
||||
if mark.name == "parametrize":
|
||||
ids = mark.kwargs.get("ids")
|
||||
argnames = mark.args[0]
|
||||
argvalues = mark.args[1]
|
||||
argnames, parametersets = ParameterSet._for_parametrize( # pyright: ignore[reportUnknownMemberType, reportPrivateUsage]
|
||||
argnames,
|
||||
argvalues,
|
||||
None,
|
||||
None, # pyright: ignore[reportArgumentType]
|
||||
None, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
if ids is None:
|
||||
idfn = None
|
||||
ids_ = None
|
||||
elif callable(ids):
|
||||
idfn = ids
|
||||
ids_ = None
|
||||
else:
|
||||
idfn = None
|
||||
ids_ = pytest.Metafunc._validate_ids( # pyright: ignore[reportPrivateUsage]
|
||||
None, # pyright: ignore[reportArgumentType]
|
||||
ids,
|
||||
parametersets,
|
||||
test.__name__, # pyright: ignore[reportArgumentType]
|
||||
)
|
||||
ids = mark.kwargs.get("ids", [])
|
||||
for id_, params in itertools.zip_longest(ids, mark.args[1]):
|
||||
id_ = getattr(params, "id", None) or id_ or repr(params)
|
||||
test_names.append(f"{test.__module__}::{test.__name__}[{id_}]")
|
||||
|
||||
id_maker = IdMaker(
|
||||
argnames,
|
||||
parametersets,
|
||||
idfn,
|
||||
ids_,
|
||||
None,
|
||||
nodeid=None,
|
||||
func_name=test.__name__,
|
||||
)
|
||||
yield from (
|
||||
f"{path.relative_to(root).as_posix()}::{test.__name__}[{id_}]"
|
||||
for id_ in id_maker.make_unique_parameterset_ids()
|
||||
)
|
||||
return test_names
|
||||
|
||||
|
||||
def _close_socket(path: Path) -> None:
|
||||
if path.exists():
|
||||
os.unlink(path)
|
||||
exit(0)
|
||||
def main(
|
||||
paths: Iterable[str],
|
||||
collect_only: bool = True,
|
||||
quiet: bool = True,
|
||||
verbosity: int = 0,
|
||||
socket_mode: bool = False,
|
||||
):
|
||||
tests = get_tests(paths)
|
||||
if tests:
|
||||
print("\n".join(tests))
|
||||
if not socket_mode:
|
||||
return
|
||||
|
||||
|
||||
async def serve_socket():
|
||||
if not SOCKET_ROOT_DIR.exists():
|
||||
SOCKET_ROOT_DIR.mkdir()
|
||||
|
||||
@@ -138,43 +73,29 @@ async def serve_socket():
|
||||
print(python_socket_path)
|
||||
return
|
||||
|
||||
atexit.register(lambda: _close_socket(python_socket_path))
|
||||
|
||||
signal.signal(signal.SIGTERM, lambda x, _: _close_socket(python_socket_path))
|
||||
signal.signal(signal.SIGINT, lambda x, _: _close_socket(python_socket_path))
|
||||
|
||||
async def handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
path = await reader.readline()
|
||||
tests = "\n".join(
|
||||
[test for test in get_tests([cast(bytes, path).decode().strip()])] # pyright: ignore[reportUnnecessaryCast]
|
||||
)
|
||||
writer.write(f"{tests}\n".encode())
|
||||
writer.close()
|
||||
|
||||
server = await asyncio.start_unix_server(
|
||||
handle_client, python_socket_path.absolute().as_posix()
|
||||
)
|
||||
|
||||
async with server:
|
||||
print(python_socket_path)
|
||||
sys.stdout.close()
|
||||
await server.serve_forever()
|
||||
|
||||
|
||||
def main(
|
||||
paths: Iterable[str],
|
||||
collect_only: bool = True,
|
||||
quiet: bool = True,
|
||||
verbosity: int = 0,
|
||||
socket_mode: bool = False,
|
||||
no_fork: bool = True,
|
||||
):
|
||||
tests = list(get_tests(paths))
|
||||
if tests:
|
||||
print("\n".join(tests))
|
||||
if not socket_mode:
|
||||
child_pid = os.fork()
|
||||
if child_pid != 0:
|
||||
return
|
||||
asyncio.run(serve_socket())
|
||||
|
||||
atexit.register(lambda: os.unlink(python_socket_path))
|
||||
|
||||
signal.signal(signal.SIGTERM, lambda x, _: os.unlink(python_socket_path))
|
||||
signal.signal(signal.SIGINT, lambda x, _: os.unlink(python_socket_path))
|
||||
|
||||
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(python_socket_path.absolute().as_posix())
|
||||
sock.listen()
|
||||
print(python_socket_path)
|
||||
while ar := sock.accept():
|
||||
con, _ = ar
|
||||
try:
|
||||
paths = con.recv(1000).decode().splitlines()
|
||||
tests = get_tests(paths)
|
||||
con.send("\n".join(tests).encode())
|
||||
|
||||
except Exception:
|
||||
logging.exception("Error")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -183,7 +104,6 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-q", dest="quiet", action="store_true")
|
||||
parser.add_argument("--verbosity", dest="verbosity", default=0)
|
||||
parser.add_argument("-s", dest="socket_mode", action="store_true")
|
||||
parser.add_argument("-f", dest="no_fork", action="store_true")
|
||||
parser.add_argument(
|
||||
"paths",
|
||||
nargs="*",
|
||||
@@ -196,5 +116,4 @@ if __name__ == "__main__":
|
||||
collect_only=args.collect_only,
|
||||
verbosity=args.verbosity,
|
||||
socket_mode=args.socket_mode,
|
||||
no_fork=args.no_fork,
|
||||
)
|
||||
|
@@ -2,7 +2,7 @@ from io import StringIO
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
from . import params_getter
|
||||
import params_getter
|
||||
|
||||
import pytest
|
||||
from _pytest._code.code import ExceptionRepr
|
||||
@@ -220,8 +220,4 @@ def extract_test_name_template(args):
|
||||
|
||||
|
||||
def collect(args):
|
||||
params_getter.main(
|
||||
[],
|
||||
socket_mode=True,
|
||||
no_fork=True,
|
||||
)
|
||||
params_getter.main([], socket_mode=True)
|
||||
|
Reference in New Issue
Block a user