Compare commits

...

1 Commits

Author SHA1 Message Date
Itai Bohadana
ff75e11bca feat(pytest): use socket instead of a shitton of processes 2025-08-26 15:48:59 +03:00
6 changed files with 264 additions and 31 deletions

View File

@@ -110,17 +110,6 @@ end
---@return string
local function scan_test_function_pattern(runner, config, python_command)
local test_function_pattern = "^test"
if runner == "pytest" and config.pytest_discovery then
local cmd = vim.tbl_flatten({ python_command, M.get_script_path(), "--pytest-extract-test-name-template" })
local _, data = lib.process.run(cmd, { stdout = true, stderr = true })
for line in vim.gsplit(data.stdout, "\n", true) do
if string.sub(line, 1, 1) == "{" and string.find(line, "python_functions") ~= nil then
local pytest_option = vim.json.decode(line)
test_function_pattern = pytest_option.python_functions
end
end
end
return test_function_pattern
end
@@ -130,7 +119,8 @@ end
---@return string
M.treesitter_queries = function(runner, config, python_command)
local test_function_pattern = scan_test_function_pattern(runner, config, python_command)
return string.format([[
return string.format(
[[
;; Match undecorated functions
((function_definition
name: (identifier) @test.name)
@@ -158,7 +148,10 @@ M.treesitter_queries = function(runner, config, python_command)
@namespace.definition
(#not-has-parent? @namespace.definition decorated_definition)
)
]], test_function_pattern, test_function_pattern)
]],
test_function_pattern,
test_function_pattern
)
end
M.get_root =
@@ -192,7 +185,7 @@ function M.get_runner(python_path)
then
return vim_test_runner
end
local runner = M.module_exists("pytest", python_path) and "pytest"
local runner = M.module_exists("pytest_", python_path) and "pytest"
or M.module_exists("django", python_path) and "django"
or "unittest"
stored_runners[command_str] = runner

View File

@@ -3,6 +3,52 @@ local logger = require("neotest.logging")
local M = {}
local unix = require("socket.unix")
local socket_path = ""
--- Run a command, connect to the UNIX socket it prints, send messages, return response
-- @param cmd string: Command to run (should output socket path)
-- @param messages table|string: One or more messages to send
-- @return string: Concatenated response from server
local function talk_unix(cmd, messages)
-- 1. Run the command and capture its output (socket path)
if not socket_path then
local handle = assert(io.popen(cmd, "r"))
socket_path = handle:read("*l")
handle:close()
assert(socket_path, "Command did not return a socket path")
end
-- 2. Connect to the unix socket
local client = assert(unix())
assert(client:connect(socket_path))
-- 3. Send message(s)
if type(messages) == "string" then
messages = { messages }
end
for _, msg in ipairs(messages) do
assert(client:send(msg .. "\n"))
end
-- 4. Read response until EOF or timeout
client:settimeout(1)
local chunks = {}
while true do
local data, err = client:receive("*l")
if not data then
if err ~= "timeout" and err ~= "closed" then
error("Socket receive error: " .. tostring(err))
end
break
end
table.insert(chunks, data)
end
client:close()
return table.concat(chunks, "\n")
end
---@async
---Add test instances for path in root to positions
---@param positions neotest.Tree
@@ -53,20 +99,13 @@ end
---@param positions neotest.Tree
---@param root string
local function discover_params(python, script, path, positions, root)
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", path })
local cmd = vim.tbl_flatten({ python, script, "--pytest-collect", "-s", path })
logger.debug("Running test instance discovery:", cmd)
local test_params = {}
local res, data = lib.process.run(cmd, { stdout = true, stderr = true })
if res ~= 0 then
logger.warn("Pytest discovery failed")
if data.stderr then
logger.debug(data.stderr)
end
return {}
end
local data = talk_unix(cmd, path)
for line in vim.gsplit(data.stdout, "\n", true) do
for line in vim.gsplit(data, "\n", true) do
local param_index = string.find(line, "[", nil, true)
if param_index then
local test_id = root .. lib.files.path.sep .. string.sub(line, 1, param_index - 1)

View File

@@ -0,0 +1,119 @@
import atexit
from collections.abc import Iterable
import hashlib
import itertools
import logging
import os
import signal
import socket
import sys
import pytest
import importlib
import importlib.util
import argparse
import inspect
from pathlib import Path
from typing import Any, Callable, cast
SOCKET_ROOT_DIR = Path("/tmp/neotest-python")
def get_tests(paths: Iterable[str]) -> list[str]:
test_names: list[str] = []
for path in (Path(path) for path in paths):
spec = importlib.util.spec_from_file_location(path.name, path)
if spec is None or spec.loader is None:
raise ModuleNotFoundError
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
tests: Iterable[tuple[str, Callable[..., Any]]] = inspect.getmembers(
mod,
predicate=lambda member: inspect.isfunction(member)
and member.__name__.startswith("test_"),
)
for _, test in tests:
if not (marks := getattr(test, "pytestmark", None)):
test_names.append(f"{test.__module__}::{test.__name__}")
continue
for mark in cast(Iterable[pytest.Mark], marks):
if mark.name == "parametrize":
ids = mark.kwargs.get("ids", [])
for id_, params in itertools.zip_longest(ids, mark.args[1]):
id_ = getattr(params, "id", None) or id_ or repr(params)
test_names.append(f"{test.__module__}::{test.__name__}[{id_}]")
return test_names
def main(
paths: Iterable[str],
collect_only: bool = True,
quiet: bool = True,
verbosity: int = 0,
socket_mode: bool = False,
):
tests = get_tests(paths)
if tests:
print("\n".join(tests))
if not socket_mode:
return
if not SOCKET_ROOT_DIR.exists():
SOCKET_ROOT_DIR.mkdir()
python_socket_path = (
SOCKET_ROOT_DIR / hashlib.sha1(sys.executable.encode()).digest().hex()
)
if python_socket_path.exists():
print(python_socket_path)
return
child_pid = os.fork()
if child_pid != 0:
return
atexit.register(lambda: os.unlink(python_socket_path))
signal.signal(signal.SIGTERM, lambda x, _: os.unlink(python_socket_path))
signal.signal(signal.SIGINT, lambda x, _: os.unlink(python_socket_path))
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
sock.bind(python_socket_path.absolute().as_posix())
sock.listen()
print(python_socket_path)
while ar := sock.accept():
con, _ = ar
try:
paths = con.recv(1000).decode().splitlines()
tests = get_tests(paths)
con.send("\n".join(tests).encode())
except Exception:
logging.exception("Error")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--collect-only", dest="collect_only", action="store_true")
parser.add_argument("-q", dest="quiet", action="store_true")
parser.add_argument("--verbosity", dest="verbosity", default=0)
parser.add_argument("-s", dest="socket_mode", action="store_true")
parser.add_argument(
"paths",
nargs="*",
)
args = parser.parse_args()
paths: list[str] = args.paths
main(
paths=paths,
quiet=args.quiet,
collect_only=args.collect_only,
verbosity=args.verbosity,
socket_mode=args.socket_mode,
)

View File

@@ -2,6 +2,7 @@ from io import StringIO
import json
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
import params_getter
import pytest
from _pytest._code.code import ExceptionRepr
@@ -214,8 +215,9 @@ class TestNameTemplateExtractor:
def extract_test_name_template(args):
pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor])
# pytest.main(args=["-k", "neotest_none"], plugins=[TestNameTemplateExtractor])
pass
def collect(args):
pytest.main(["--collect-only", "--verbosity=0", "-q"] + args)
params_getter.main([], socket_mode=True)

View File

@@ -1,6 +1,12 @@
[project]
name = "neotest-python"
version = "0.0.1"
dependencies = [
"pytest>=8.4.1",
]
[tools.black]
line-length=120
line-length = 120
[tool.isort]
profile = "black"
@@ -8,8 +14,7 @@ multi_line_output = 3
[tool.pytest.ini_options]
filterwarnings = [
"error",
"ignore::pytest.PytestCollectionWarning",
"ignore:::pynvim[.*]"
"error",
"ignore::pytest.PytestCollectionWarning",
"ignore:::pynvim[.*]",
]

75
uv.lock generated Normal file
View File

@@ -0,0 +1,75 @@
version = 1
revision = 2
requires-python = ">=3.13"
[[package]]
name = "colorama"
version = "0.4.6"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
]
[[package]]
name = "iniconfig"
version = "2.1.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" },
]
[[package]]
name = "neotest-python"
version = "0.0.1"
source = { virtual = "." }
dependencies = [
{ name = "pytest" },
]
[package.metadata]
requires-dist = [{ name = "pytest", specifier = ">=8.4.1" }]
[[package]]
name = "packaging"
version = "25.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "pygments"
version = "2.19.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" },
]
[[package]]
name = "pytest"
version = "8.4.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714, upload-time = "2025-06-18T05:48:06.109Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474, upload-time = "2025-06-18T05:48:03.955Z" },
]