diff --git a/lua/neotest-python/base.lua b/lua/neotest-python/base.lua index 6d4a3d3..319b45c 100644 --- a/lua/neotest-python/base.lua +++ b/lua/neotest-python/base.lua @@ -110,17 +110,6 @@ end ---@return string local function scan_test_function_pattern(runner, config, python_command) local test_function_pattern = "^test" - if runner == "pytest" and config.pytest_discovery then - local cmd = vim.tbl_flatten({ python_command, M.get_script_path(), "--pytest-extract-test-name-template" }) - local _, data = lib.process.run(cmd, { stdout = true, stderr = true }) - - for line in vim.gsplit(data.stdout, "\n", true) do - if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then - local pytest_option = vim.json.decode(line) - test_function_pattern = pytest_option.python_functions - end - end - end return test_function_pattern end @@ -130,7 +119,8 @@ end ---@return string M.treesitter_queries = function(runner, config, python_command) local test_function_pattern = scan_test_function_pattern(runner, config, python_command) - return string.format([[ + return string.format( + [[ ;; Match undecorated functions ((function_definition name: (identifier) @test.name) @@ -158,7 +148,10 @@ M.treesitter_queries = function(runner, config, python_command) @namespace.definition (#not-has-parent? @namespace.definition decorated_definition) ) - ]], test_function_pattern, test_function_pattern) + ]], + test_function_pattern, + test_function_pattern + ) end M.get_root = @@ -192,7 +185,7 @@ function M.get_runner(python_path) then return vim_test_runner end - local runner = M.module_exists("pytest", python_path) and "pytest" + local runner = M.module_exists("pytest_", python_path) and "pytest" or M.module_exists("django", python_path) and "django" or "unittest" stored_runners[command_str] = runner diff --git a/lua/neotest-python/pytest.lua b/lua/neotest-python/pytest.lua index d462158..16aeabf 100644 --- a/lua/neotest-python/pytest.lua +++ b/lua/neotest-python/pytest.lua @@ -3,6 +3,52 @@ local logger = require("neotest.logging") local M = {} +local unix = require("socket.unix") +local socket_path = "" + +--- Run a command, connect to the UNIX socket it prints, send messages, return response +-- @param cmd string: Command to run (should output socket path) +-- @param messages table|string: One or more messages to send +-- @return string: Concatenated response from server +local function talk_unix(cmd, messages) + -- 1. Run the command and capture its output (socket path) + if not socket_path then + local handle = assert(io.popen(cmd, "r")) + socket_path = handle:read("*l") + handle:close() + assert(socket_path, "Command did not return a socket path") + end + + -- 2. Connect to the unix socket + local client = assert(unix()) + assert(client:connect(socket_path)) + + -- 3. Send message(s) + if type(messages) == "string" then + messages = { messages } + end + for _, msg in ipairs(messages) do + assert(client:send(msg .. "\n")) + end + + -- 4. Read response until EOF or timeout + client:settimeout(1) + local chunks = {} + while true do + local data, err = client:receive("*l") + if not data then + if err ~= "timeout" and err ~= "closed" then + error("Socket receive error: " .. tostring(err)) + end + break + end + table.insert(chunks, data) + end + + client:close() + return table.concat(chunks, "\n") +end + ---@async ---Add test instances for path in root to positions ---@param positions neotest.Tree @@ -53,20 +99,13 @@ end ---@param positions neotest.Tree ---@param root string local function discover_params(python, script, path, positions, root) - local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", path }) + local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", "-s", path }) logger.debug("Running test instance discovery:", cmd) local test_params = {} - local res, data = lib.process.run(cmd, { stdout = true, stderr = true }) - if res ~= 0 then - logger.warn("Pytest discovery failed") - if data.stderr then - logger.debug(data.stderr) - end - return {} - end + local data = talk_unix(cmd, path) - for line in vim.gsplit(data.stdout, "\n", true) do + for line in vim.gsplit(data, "\n", true) do local param_index = string.find(line, "[", nil, true) if param_index then local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1) diff --git a/neotest_python/params_getter.py b/neotest_python/params_getter.py new file mode 100644 index 0000000..b3dfb70 --- /dev/null +++ b/neotest_python/params_getter.py @@ -0,0 +1,119 @@ +import atexit +from collections.abc import Iterable +import hashlib +import itertools +import logging +import os +import signal +import socket +import sys +import pytest +import importlib +import importlib.util +import argparse +import inspect +from pathlib import Path +from typing import Any, Callable, cast + +SOCKET_ROOT_DIR = Path("/tmp/neotest-python") + + +def get_tests(paths: Iterable[str]) -> list[str]: + test_names: list[str] = [] + + for path in (Path(path) for path in paths): + spec = importlib.util.spec_from_file_location(path.name, path) + if spec is None or spec.loader is None: + raise ModuleNotFoundError + + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers( + mod, + predicate=lambda member: inspect.isfunction(member) + and member.__name__.startswith("test_"), + ) + for _, test in tests: + if not (marks := getattr(test, "pytestmark", None)): + test_names.append(f"{test.__module__}::{test.__name__}") + continue + + for mark in cast(Iterable[pytest.Mark], marks): + if mark.name == "parametrize": + ids = mark.kwargs.get("ids", []) + for id_, params in itertools.zip_longest(ids, mark.args[1]): + id_ = getattr(params, "id", None) or id_ or repr(params) + test_names.append(f"{test.__module__}::{test.__name__}[{id_}]") + + return test_names + + +def main( + paths: Iterable[str], + collect_only: bool = True, + quiet: bool = True, + verbosity: int = 0, + socket_mode: bool = False, +): + tests = get_tests(paths) + if tests: + print("\n".join(tests)) + if not socket_mode: + return + + if not SOCKET_ROOT_DIR.exists(): + SOCKET_ROOT_DIR.mkdir() + + python_socket_path = ( + SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex() + ) + + if python_socket_path.exists(): + print(python_socket_path) + return + + child_pid = os.fork() + if child_pid != 0: + return + + atexit.register(lambda: os.unlink(python_socket_path)) + + signal.signal(signal.SIGTERM, lambda x, _: os.unlink(python_socket_path)) + signal.signal(signal.SIGINT, lambda x, _: os.unlink(python_socket_path)) + + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(python_socket_path.absolute().as_posix()) + sock.listen() + print(python_socket_path) + while ar := sock.accept(): + con, _ = ar + try: + paths = con.recv(1000).decode().splitlines() + tests = get_tests(paths) + con.send("\n".join(tests).encode()) + + except Exception: + logging.exception("Error") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--collect-only", dest="collect_only", action="store_true") + parser.add_argument("-q", dest="quiet", action="store_true") + parser.add_argument("--verbosity", dest="verbosity", default=0) + parser.add_argument("-s", dest="socket_mode", action="store_true") + parser.add_argument( + "paths", + nargs="*", + ) + args = parser.parse_args() + paths: list[str] = args.paths + main( + paths=paths, + quiet=args.quiet, + collect_only=args.collect_only, + verbosity=args.verbosity, + socket_mode=args.socket_mode, + ) diff --git a/neotest_python/pytest.py b/neotest_python/pytest_.py similarity index 98% rename from neotest_python/pytest.py rename to neotest_python/pytest_.py index e12beb2..0d80d82 100644 --- a/neotest_python/pytest.py +++ b/neotest_python/pytest_.py @@ -2,6 +2,7 @@ from io import StringIO import json from pathlib import Path from typing import Callable, Dict, List, Optional, Union +import params_getter import pytest from _pytest._code.code import ExceptionRepr @@ -214,8 +215,9 @@ class TestNameTemplateExtractor: def extract_test_name_template(args): - pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor]) + # pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor]) + pass def collect(args): - pytest.main(["--collect-only", "--verbosity=0", "-q"] + args) + params_getter.main([], socket_mode=True) diff --git a/pyproject.toml b/pyproject.toml index 780b5ab..fe88029 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,12 @@ +[project] +name = "neotest-python" +version = "0.0.1" +dependencies = [ + "pytest>=8.4.1", +] [tools.black] -line-length=120 +line-length = 120 [tool.isort] profile = "black" @@ -8,8 +14,7 @@ multi_line_output = 3 [tool.pytest.ini_options] filterwarnings = [ - "error", - "ignore::pytest.PytestCollectionWarning", - "ignore:::pynvim[.*]" + "error", + "ignore::pytest.PytestCollectionWarning", + "ignore:::pynvim[.*]", ] - diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..5213b9b --- /dev/null +++ b/uv.lock @@ -0,0 +1,75 @@ +version = 1 +revision = 2 +requires-python = ">=3.13" + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, +] + +[[package]] +name = "neotest-python" +version = "0.0.1" +source = { virtual = "." } +dependencies = [ + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [{ name = "pytest", specifier = ">=8.4.1" }] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" }, +]