fix(pytest): handle parameterized tests without pytest discovery

Only emits position IDs with parameters when pytest discovery is enabled

See #36 and #59
This commit is contained in:
Rónán Carrigan
2023-12-20 18:49:01 +00:00
parent 48bf141103
commit 27a2676aa0
5 changed files with 59 additions and 28 deletions

View File

@@ -59,10 +59,14 @@ local get_runner = function(python_command)
if vim_test_runner == "pyunit" then if vim_test_runner == "pyunit" then
return "unittest" return "unittest"
end end
if vim_test_runner and lib.func_util.index({ "unittest", "pytest", "django" }, vim_test_runner) then if
vim_test_runner and lib.func_util.index({ "unittest", "pytest", "django" }, vim_test_runner)
then
return vim_test_runner return vim_test_runner
end end
local runner = base.module_exists("pytest", python_command) and "pytest" or base.module_exists("django", python_command) and "django" or "unittest" local runner = base.module_exists("pytest", python_command) and "pytest"
or base.module_exists("django", python_command) and "django"
or "unittest"
stored_runners[command_str] = runner stored_runners[command_str] = runner
return runner return runner
end end
@@ -71,7 +75,7 @@ end
local PythonNeotestAdapter = { name = "neotest-python" } local PythonNeotestAdapter = { name = "neotest-python" }
PythonNeotestAdapter.root = PythonNeotestAdapter.root =
lib.files.match_root_pattern("pyproject.toml", "setup.cfg", "mypy.ini", "pytest.ini", "setup.py") lib.files.match_root_pattern("pyproject.toml", "setup.cfg", "mypy.ini", "pytest.ini", "setup.py")
function PythonNeotestAdapter.is_test_file(file_path) function PythonNeotestAdapter.is_test_file(file_path)
return is_test_file(file_path) return is_test_file(file_path)
@@ -147,9 +151,15 @@ function PythonNeotestAdapter.build_spec(args)
stream_path, stream_path,
"--runner", "--runner",
runner, runner,
"--",
vim.list_extend(get_args(runner, position, args.strategy), args.extra_args or {}),
}) })
if pytest_discover_instances then
table.insert(script_args, "--emit-parameterized-ids")
end
vim.list_extend(script_args, get_args(runner, position, args.strategy))
if args.extra_args then
vim.list_extend(script_args, args.extra_args)
end
if position then if position then
table.insert(script_args, position.id) table.insert(script_args, position.id)
end end

View File

@@ -12,11 +12,11 @@ class TestRunner(str, Enum):
DJANGO = "django" DJANGO = "django"
def get_adapter(runner: TestRunner) -> NeotestAdapter: def get_adapter(runner: TestRunner, emit_parameterized_ids: bool) -> NeotestAdapter:
if runner == TestRunner.PYTEST: if runner == TestRunner.PYTEST:
from .pytest import PytestNeotestAdapter from .pytest import PytestNeotestAdapter
return PytestNeotestAdapter() return PytestNeotestAdapter(emit_parameterized_ids)
elif runner == TestRunner.UNITTEST: elif runner == TestRunner.UNITTEST:
from .unittest import UnittestNeotestAdapter from .unittest import UnittestNeotestAdapter
@@ -42,6 +42,11 @@ parser.add_argument(
required=True, required=True,
help="File to stream result JSON to", help="File to stream result JSON to",
) )
parser.add_argument(
"--emit-parameterized-ids",
action="store_true",
help="Emit parameterized test ids (pytest only)",
)
parser.add_argument("args", nargs="*") parser.add_argument("args", nargs="*")
@@ -49,11 +54,12 @@ def main(argv: List[str]):
if "--pytest-collect" in argv: if "--pytest-collect" in argv:
argv.remove("--pytest-collect") argv.remove("--pytest-collect")
from .pytest import collect from .pytest import collect
collect(argv) collect(argv)
return return
args = parser.parse_args(argv) args = parser.parse_args(argv)
adapter = get_adapter(TestRunner(args.runner)) adapter = get_adapter(TestRunner(args.runner), args.emit_parameterized_ids)
with open(args.stream_file, "w") as stream_file: with open(args.stream_file, "w") as stream_file:

View File

@@ -1,17 +1,19 @@
import inspect import inspect
import subprocess
import os import os
import subprocess
import sys import sys
import traceback import traceback
import unittest import unittest
from argparse import ArgumentParser from argparse import ArgumentParser
from pathlib import Path from pathlib import Path
from types import TracebackType from types import TracebackType
from typing import Any, Tuple, Dict, List from typing import Any, Dict, List, Tuple
from unittest import TestCase from unittest import TestCase
from unittest.runner import TextTestResult from unittest.runner import TextTestResult
from django import setup as django_setup from django import setup as django_setup
from django.test.runner import DiscoverRunner from django.test.runner import DiscoverRunner
from .base import NeotestAdapter, NeotestError, NeotestResultStatus from .base import NeotestAdapter, NeotestError, NeotestResultStatus
@@ -67,11 +69,7 @@ class DjangoNeotestAdapter(CaseUtilsMixin, NeotestAdapter):
@classmethod @classmethod
def add_arguments(cls, parser): def add_arguments(cls, parser):
DiscoverRunner.add_arguments(parser) DiscoverRunner.add_arguments(parser)
parser.add_argument( parser.add_argument("--verbosity", nargs="?", default=2)
"--verbosity",
nargs="?",
default=2
)
parser.add_argument( parser.add_argument(
"--failfast", "--failfast",
action="store_true", action="store_true",

View File

@@ -2,24 +2,32 @@ from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
from .base import NeotestAdapter, NeotestError, NeotestResult, NeotestResultStatus
import pytest import pytest
from _pytest._code.code import ExceptionRepr from _pytest._code.code import ExceptionRepr
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
from .base import NeotestAdapter, NeotestError, NeotestResult, NeotestResultStatus
class PytestNeotestAdapter(NeotestAdapter): class PytestNeotestAdapter(NeotestAdapter):
def __init__(self, emit_parameterized_ids: bool):
self.emit_parameterized_ids = emit_parameterized_ids
def run( def run(
self, self,
args: List[str], args: List[str],
stream: Callable[[str, NeotestResult], None], stream: Callable[[str, NeotestResult], None],
) -> Dict[str, NeotestResult]: ) -> Dict[str, NeotestResult]:
result_collector = NeotestResultCollector(self, stream=stream) result_collector = NeotestResultCollector(
pytest.main(args=args, plugins=[ self, stream=stream, emit_parameterized_ids=self.emit_parameterized_ids
result_collector, )
NeotestDebugpyPlugin(), pytest.main(
]) args=args,
plugins=[
result_collector,
NeotestDebugpyPlugin(),
],
)
return result_collector.results return result_collector.results
@@ -28,9 +36,11 @@ class NeotestResultCollector:
self, self,
adapter: PytestNeotestAdapter, adapter: PytestNeotestAdapter,
stream: Callable[[str, NeotestResult], None], stream: Callable[[str, NeotestResult], None],
emit_parameterized_ids: bool,
): ):
self.stream = stream self.stream = stream
self.adapter = adapter self.adapter = adapter
self.emit_parameterized_ids = emit_parameterized_ids
self.pytest_config: Optional["pytest.Config"] = None # type: ignore self.pytest_config: Optional["pytest.Config"] = None # type: ignore
self.results: Dict[str, NeotestResult] = {} self.results: Dict[str, NeotestResult] = {}
@@ -80,7 +90,9 @@ class NeotestResultCollector:
self.pytest_config = config self.pytest_config = config
@pytest.hookimpl(hookwrapper=True) @pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(self, item: "pytest.Item", call: "pytest.CallInfo") -> None: def pytest_runtest_makereport(
self, item: "pytest.Item", call: "pytest.CallInfo"
) -> None:
# pytest generates the report.outcome field in its internal # pytest generates the report.outcome field in its internal
# pytest_runtest_makereport implementation, so call it first. (We don't # pytest_runtest_makereport implementation, so call it first. (We don't
# implement pytest_runtest_logreport because it doesn't have access to # implement pytest_runtest_logreport because it doesn't have access to
@@ -105,8 +117,10 @@ class NeotestResultCollector:
msg_prefix = "" msg_prefix = ""
if getattr(item, "callspec", None) is not None: if getattr(item, "callspec", None) is not None:
# Parametrized test # Parametrized test
msg_prefix = f"[{item.callspec.id}] " if self.emit_parameterized_ids:
pos_id += f"[{item.callspec.id}]" pos_id += f"[{item.callspec.id}]"
else:
msg_prefix = f"[{item.callspec.id}] "
if report.outcome == "failed": if report.outcome == "failed":
exc_repr = report.longrepr exc_repr = report.longrepr
# Test fails due to condition outside of test e.g. xfail # Test fails due to condition outside of test e.g. xfail
@@ -119,7 +133,9 @@ class NeotestResultCollector:
for traceback_entry in reversed(call.excinfo.traceback): for traceback_entry in reversed(call.excinfo.traceback):
if str(traceback_entry.path) == abs_path: if str(traceback_entry.path) == abs_path:
error_line = traceback_entry.lineno error_line = traceback_entry.lineno
errors.append({"message": msg_prefix + error_message, "line": error_line}) errors.append(
{"message": msg_prefix + error_message, "line": error_line}
)
else: else:
# TODO: Figure out how these are returned and how to represent # TODO: Figure out how these are returned and how to represent
raise Exception( raise Exception(
@@ -159,6 +175,7 @@ class NeotestDebugpyPlugin:
""" """
# Reference: https://github.com/microsoft/debugpy/issues/723 # Reference: https://github.com/microsoft/debugpy/issues/723
import threading import threading
try: try:
import pydevd import pydevd
except ImportError: except ImportError:
@@ -180,4 +197,4 @@ class NeotestDebugpyPlugin:
def collect(args): def collect(args):
pytest.main(['--collect-only', '-q'] + args) pytest.main(["--collect-only", "-q"] + args)

View File

@@ -17,7 +17,7 @@ class UnittestNeotestAdapter(NeotestAdapter):
return str(Path(inspect.getmodule(case).__file__).absolute()) # type: ignore return str(Path(inspect.getmodule(case).__file__).absolute()) # type: ignore
def case_id_elems(self, case) -> List[str]: def case_id_elems(self, case) -> List[str]:
if case.__class__.__name__ == '_SubTest': if case.__class__.__name__ == "_SubTest":
case = case.test_case case = case.test_case
file = self.case_file(case) file = self.case_file(case)
elems = [file, case.__class__.__name__] elems = [file, case.__class__.__name__]