From 2dc9c95fe9c895cb086ddd749875ad714553e336 Mon Sep 17 00:00:00 2001 From: Jongwook Choi Date: Sat, 29 Oct 2022 08:23:40 -0400 Subject: [PATCH] refactor(pytest): PytestNeotestAdapter (#24) Having `NeotestResultCollector` (a pytest plugin) as a inner local class would make the code a bit difficult to read due to quite much indentation. This commit does refactoring on NeotestResultCollector to make it a module-level class with a reference to NeotestAdapter. This refactoring would make easier adding more pytest plugins (e.g., debugger integration) in the future. There should be no changes in behaviors. --- neotest_python/__init__.py | 6 +- neotest_python/base.py | 11 ++- neotest_python/pytest.py | 175 ++++++++++++++++++++----------------- 3 files changed, 107 insertions(+), 85 deletions(-) diff --git a/neotest_python/__init__.py b/neotest_python/__init__.py index 36d97d0..a3d1f56 100644 --- a/neotest_python/__init__.py +++ b/neotest_python/__init__.py @@ -3,7 +3,7 @@ import json from enum import Enum from typing import List -from neotest_python.base import NeotestResult +from neotest_python.base import NeotestAdapter, NeotestResult class TestRunner(str, Enum): @@ -11,7 +11,7 @@ class TestRunner(str, Enum): UNITTEST = "unittest" -def get_adapter(runner: TestRunner): +def get_adapter(runner: TestRunner) -> NeotestAdapter: if runner == TestRunner.PYTEST: from .pytest import PytestNeotestAdapter @@ -43,6 +43,7 @@ parser.add_argument("args", nargs="*") def main(argv: List[str]): args = parser.parse_args(argv) adapter = get_adapter(TestRunner(args.runner)) + with open(args.stream_file, "w") as stream_file: def stream(pos_id: str, result: NeotestResult): @@ -50,5 +51,6 @@ def main(argv: List[str]): stream_file.flush() results = adapter.run(args.args, stream) + with open(args.results_file, "w") as results_file: json.dump(results, results_file) diff --git a/neotest_python/base.py b/neotest_python/base.py index 067b646..9e2975c 100644 --- a/neotest_python/base.py +++ b/neotest_python/base.py @@ -1,5 +1,6 @@ +import abc from enum import Enum -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional class NeotestResultStatus(str, Enum): @@ -29,7 +30,8 @@ else: NeotestResult = Dict -class NeotestAdapter: +class NeotestAdapter(abc.ABC): + def update_result( self, base: Optional[NeotestResult], update: NeotestResult ) -> NeotestResult: @@ -40,3 +42,8 @@ class NeotestAdapter: "errors": (base.get("errors") or []) + (update.get("errors") or []) or None, "short": (base.get("short") or "") + (update.get("short") or ""), } + + @abc.abstractmethod + def run(self, args: List[str], stream: Callable): + del args, stream + raise NotImplementedError diff --git a/neotest_python/pytest.py b/neotest_python/pytest.py index f7cf7b6..47d0384 100644 --- a/neotest_python/pytest.py +++ b/neotest_python/pytest.py @@ -10,7 +10,33 @@ if TYPE_CHECKING: class PytestNeotestAdapter(NeotestAdapter): - def get_short_output(self, config: "Config", report: "TestReport") -> Optional[str]: + + def run( + self, + args: List[str], + stream: Callable[[str, NeotestResult], None], + ) -> Dict[str, NeotestResult]: + import pytest + + result_collector = NeotestResultCollector(self, stream=stream) + pytest.main(args=args, plugins=[result_collector]) + return result_collector.results + + +class NeotestResultCollector: + + def __init__( + self, + adapter: PytestNeotestAdapter, + stream: Callable[[str, NeotestResult], None], + ): + self.stream = stream + self.adapter = adapter + + self.pytest_config: "Config" = None # type: ignore + self.results: Dict[str, NeotestResult] = {} + + def _get_short_output(self, config: "Config", report: "TestReport") -> Optional[str]: from _pytest.terminal import TerminalReporter buffer = StringIO() @@ -32,88 +58,75 @@ class PytestNeotestAdapter(NeotestAdapter): buffer.seek(0) return buffer.read() - def run( - self, args: List[str], stream: Callable[[str, NeotestResult], None] - ) -> Dict[str, NeotestResult]: - results: Dict[str, NeotestResult] = {} - pytest_config: "Config" - from _pytest._code.code import ExceptionChainRepr + def pytest_deselected(self, items: List["pytest.Item"]): + for report in items: + file_path, *name_path = report.nodeid.split("::") + abs_path = str(Path(self.pytest_config.rootdir, file_path)) + test_name, *namespaces = reversed(name_path) + valid_test_name, *params = test_name.split("[") # ] + pos_id = "::".join([abs_path, *namespaces, valid_test_name]) + result = self.adapter.update_result( + self.results.get(pos_id), + { + "short": None, + "status": NeotestResultStatus.SKIPPED, + "errors": [], + }, + ) + if not params: + self.stream(pos_id, result) + self.results[pos_id] = result - class NeotestResultCollector: - @staticmethod - def pytest_deselected(items: List): - for report in items: - file_path, *name_path = report.nodeid.split("::") - abs_path = str(Path(pytest_config.rootdir, file_path)) - test_name, *namespaces = reversed(name_path) - valid_test_name, *params = test_name.split("[") # ] - pos_id = "::".join([abs_path, *namespaces, valid_test_name]) - result = self.update_result( - results.get(pos_id), - { - "short": None, - "status": NeotestResultStatus.SKIPPED, - "errors": [], - }, - ) - if not params: - stream(pos_id, result) - results[pos_id] = result + def pytest_cmdline_main(self, config: "Config"): + self.pytest_config = config - @staticmethod - def pytest_cmdline_main(config: "Config"): - nonlocal pytest_config - pytest_config = config + def pytest_runtest_logreport(self, report: "TestReport"): + if report.when != "call" and not ( + report.outcome == "skipped" and report.when == "setup" + ): + return - @staticmethod - def pytest_runtest_logreport(report: "TestReport"): - if report.when != "call" and not ( - report.outcome == "skipped" and report.when == "setup" - ): - return - file_path, *name_path = report.nodeid.split("::") - abs_path = str(Path(pytest_config.rootdir, file_path)) - test_name, *namespaces = reversed(name_path) - valid_test_name, *params = test_name.split("[") # ] - pos_id = "::".join([abs_path, *namespaces, valid_test_name]) + file_path, *name_path = report.nodeid.split("::") + abs_path = str(Path(self.pytest_config.rootdir, file_path)) + test_name, *namespaces = reversed(name_path) + valid_test_name, *params = test_name.split("[") # ] + pos_id = "::".join([abs_path, *namespaces, valid_test_name]) - errors: List[NeotestError] = [] - short = self.get_short_output(pytest_config, report) - if report.outcome == "failed": - exc_repr = report.longrepr - # Test fails due to condition outside of test e.g. xfail - if isinstance(exc_repr, str): - errors.append({"message": exc_repr, "line": None}) - # Test failed internally - elif isinstance(exc_repr, ExceptionChainRepr): - reprtraceback = exc_repr.reprtraceback - error_message = exc_repr.reprcrash.message # type: ignore - error_line = None - for repr in reversed(reprtraceback.reprentries): - if ( - hasattr(repr, "reprfileloc") - and repr.reprfileloc.path == file_path - ): - error_line = repr.reprfileloc.lineno - 1 - errors.append({"message": error_message, "line": error_line}) - else: - # TODO: Figure out how these are returned and how to represent - raise Exception( - "Unhandled error type, please report to neotest-python repo" - ) - result = self.update_result( - results.get(pos_id), - { - "short": short, - "status": NeotestResultStatus(report.outcome), - "errors": errors, - }, + errors: List[NeotestError] = [] + short = self._get_short_output(self.pytest_config, report) + + if report.outcome == "failed": + from _pytest._code.code import ExceptionChainRepr + + exc_repr = report.longrepr + # Test fails due to condition outside of test e.g. xfail + if isinstance(exc_repr, str): + errors.append({"message": exc_repr, "line": None}) + # Test failed internally + elif isinstance(exc_repr, ExceptionChainRepr): + reprtraceback = exc_repr.reprtraceback + error_message = exc_repr.reprcrash.message # type: ignore + error_line = None + for repr in reversed(reprtraceback.reprentries): + if ( + hasattr(repr, "reprfileloc") + and repr.reprfileloc.path == file_path + ): + error_line = repr.reprfileloc.lineno - 1 + errors.append({"message": error_message, "line": error_line}) + else: + # TODO: Figure out how these are returned and how to represent + raise Exception( + "Unhandled error type, please report to neotest-python repo" ) - if not params: - stream(pos_id, result) - results[pos_id] = result - - import pytest - - pytest.main(args=args, plugins=[NeotestResultCollector]) - return results + result: NeotestResult = self.adapter.update_result( + self.results.get(pos_id), + { + "short": short, + "status": NeotestResultStatus(report.outcome), + "errors": errors, + }, + ) + if not params: + self.stream(pos_id, result) + self.results[pos_id] = result