refactor(pytest): PytestNeotestAdapter (#24)

Having `NeotestResultCollector` (a pytest plugin) as a inner local
class would make the code a bit difficult to read due to quite much
indentation. This commit does refactoring on NeotestResultCollector
to make it a module-level class with a reference to NeotestAdapter.

This refactoring would make easier adding more pytest plugins
(e.g., debugger integration) in the future.

There should be no changes in behaviors.
This commit is contained in:
Jongwook Choi
2022-10-29 08:23:40 -04:00
committed by GitHub
parent 9e2db9375c
commit 2dc9c95fe9
3 changed files with 107 additions and 85 deletions

View File

@@ -3,7 +3,7 @@ import json
from enum import Enum from enum import Enum
from typing import List from typing import List
from neotest_python.base import NeotestResult from neotest_python.base import NeotestAdapter, NeotestResult
class TestRunner(str, Enum): class TestRunner(str, Enum):
@@ -11,7 +11,7 @@ class TestRunner(str, Enum):
UNITTEST = "unittest" UNITTEST = "unittest"
def get_adapter(runner: TestRunner): def get_adapter(runner: TestRunner) -> NeotestAdapter:
if runner == TestRunner.PYTEST: if runner == TestRunner.PYTEST:
from .pytest import PytestNeotestAdapter from .pytest import PytestNeotestAdapter
@@ -43,6 +43,7 @@ parser.add_argument("args", nargs="*")
def main(argv: List[str]): def main(argv: List[str]):
args = parser.parse_args(argv) args = parser.parse_args(argv)
adapter = get_adapter(TestRunner(args.runner)) adapter = get_adapter(TestRunner(args.runner))
with open(args.stream_file, "w") as stream_file: with open(args.stream_file, "w") as stream_file:
def stream(pos_id: str, result: NeotestResult): def stream(pos_id: str, result: NeotestResult):
@@ -50,5 +51,6 @@ def main(argv: List[str]):
stream_file.flush() stream_file.flush()
results = adapter.run(args.args, stream) results = adapter.run(args.args, stream)
with open(args.results_file, "w") as results_file: with open(args.results_file, "w") as results_file:
json.dump(results, results_file) json.dump(results, results_file)

View File

@@ -1,5 +1,6 @@
import abc
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Optional from typing import TYPE_CHECKING, Callable, Dict, List, Optional
class NeotestResultStatus(str, Enum): class NeotestResultStatus(str, Enum):
@@ -29,7 +30,8 @@ else:
NeotestResult = Dict NeotestResult = Dict
class NeotestAdapter: class NeotestAdapter(abc.ABC):
def update_result( def update_result(
self, base: Optional[NeotestResult], update: NeotestResult self, base: Optional[NeotestResult], update: NeotestResult
) -> NeotestResult: ) -> NeotestResult:
@@ -40,3 +42,8 @@ class NeotestAdapter:
"errors": (base.get("errors") or []) + (update.get("errors") or []) or None, "errors": (base.get("errors") or []) + (update.get("errors") or []) or None,
"short": (base.get("short") or "") + (update.get("short") or ""), "short": (base.get("short") or "") + (update.get("short") or ""),
} }
@abc.abstractmethod
def run(self, args: List[str], stream: Callable):
del args, stream
raise NotImplementedError

View File

@@ -10,7 +10,33 @@ if TYPE_CHECKING:
class PytestNeotestAdapter(NeotestAdapter): class PytestNeotestAdapter(NeotestAdapter):
def get_short_output(self, config: "Config", report: "TestReport") -> Optional[str]:
def run(
self,
args: List[str],
stream: Callable[[str, NeotestResult], None],
) -> Dict[str, NeotestResult]:
import pytest
result_collector = NeotestResultCollector(self, stream=stream)
pytest.main(args=args, plugins=[result_collector])
return result_collector.results
class NeotestResultCollector:
def __init__(
self,
adapter: PytestNeotestAdapter,
stream: Callable[[str, NeotestResult], None],
):
self.stream = stream
self.adapter = adapter
self.pytest_config: "Config" = None # type: ignore
self.results: Dict[str, NeotestResult] = {}
def _get_short_output(self, config: "Config", report: "TestReport") -> Optional[str]:
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
buffer = StringIO() buffer = StringIO()
@@ -32,88 +58,75 @@ class PytestNeotestAdapter(NeotestAdapter):
buffer.seek(0) buffer.seek(0)
return buffer.read() return buffer.read()
def run( def pytest_deselected(self, items: List["pytest.Item"]):
self, args: List[str], stream: Callable[[str, NeotestResult], None] for report in items:
) -> Dict[str, NeotestResult]: file_path, *name_path = report.nodeid.split("::")
results: Dict[str, NeotestResult] = {} abs_path = str(Path(self.pytest_config.rootdir, file_path))
pytest_config: "Config" test_name, *namespaces = reversed(name_path)
from _pytest._code.code import ExceptionChainRepr valid_test_name, *params = test_name.split("[") # ]
pos_id = "::".join([abs_path, *namespaces, valid_test_name])
result = self.adapter.update_result(
self.results.get(pos_id),
{
"short": None,
"status": NeotestResultStatus.SKIPPED,
"errors": [],
},
)
if not params:
self.stream(pos_id, result)
self.results[pos_id] = result
class NeotestResultCollector: def pytest_cmdline_main(self, config: "Config"):
@staticmethod self.pytest_config = config
def pytest_deselected(items: List):
for report in items:
file_path, *name_path = report.nodeid.split("::")
abs_path = str(Path(pytest_config.rootdir, file_path))
test_name, *namespaces = reversed(name_path)
valid_test_name, *params = test_name.split("[") # ]
pos_id = "::".join([abs_path, *namespaces, valid_test_name])
result = self.update_result(
results.get(pos_id),
{
"short": None,
"status": NeotestResultStatus.SKIPPED,
"errors": [],
},
)
if not params:
stream(pos_id, result)
results[pos_id] = result
@staticmethod def pytest_runtest_logreport(self, report: "TestReport"):
def pytest_cmdline_main(config: "Config"): if report.when != "call" and not (
nonlocal pytest_config report.outcome == "skipped" and report.when == "setup"
pytest_config = config ):
return
@staticmethod file_path, *name_path = report.nodeid.split("::")
def pytest_runtest_logreport(report: "TestReport"): abs_path = str(Path(self.pytest_config.rootdir, file_path))
if report.when != "call" and not ( test_name, *namespaces = reversed(name_path)
report.outcome == "skipped" and report.when == "setup" valid_test_name, *params = test_name.split("[") # ]
): pos_id = "::".join([abs_path, *namespaces, valid_test_name])
return
file_path, *name_path = report.nodeid.split("::")
abs_path = str(Path(pytest_config.rootdir, file_path))
test_name, *namespaces = reversed(name_path)
valid_test_name, *params = test_name.split("[") # ]
pos_id = "::".join([abs_path, *namespaces, valid_test_name])
errors: List[NeotestError] = [] errors: List[NeotestError] = []
short = self.get_short_output(pytest_config, report) short = self._get_short_output(self.pytest_config, report)
if report.outcome == "failed":
exc_repr = report.longrepr if report.outcome == "failed":
# Test fails due to condition outside of test e.g. xfail from _pytest._code.code import ExceptionChainRepr
if isinstance(exc_repr, str):
errors.append({"message": exc_repr, "line": None}) exc_repr = report.longrepr
# Test failed internally # Test fails due to condition outside of test e.g. xfail
elif isinstance(exc_repr, ExceptionChainRepr): if isinstance(exc_repr, str):
reprtraceback = exc_repr.reprtraceback errors.append({"message": exc_repr, "line": None})
error_message = exc_repr.reprcrash.message # type: ignore # Test failed internally
error_line = None elif isinstance(exc_repr, ExceptionChainRepr):
for repr in reversed(reprtraceback.reprentries): reprtraceback = exc_repr.reprtraceback
if ( error_message = exc_repr.reprcrash.message # type: ignore
hasattr(repr, "reprfileloc") error_line = None
and repr.reprfileloc.path == file_path for repr in reversed(reprtraceback.reprentries):
): if (
error_line = repr.reprfileloc.lineno - 1 hasattr(repr, "reprfileloc")
errors.append({"message": error_message, "line": error_line}) and repr.reprfileloc.path == file_path
else: ):
# TODO: Figure out how these are returned and how to represent error_line = repr.reprfileloc.lineno - 1
raise Exception( errors.append({"message": error_message, "line": error_line})
"Unhandled error type, please report to neotest-python repo" else:
) # TODO: Figure out how these are returned and how to represent
result = self.update_result( raise Exception(
results.get(pos_id), "Unhandled error type, please report to neotest-python repo"
{
"short": short,
"status": NeotestResultStatus(report.outcome),
"errors": errors,
},
) )
if not params: result: NeotestResult = self.adapter.update_result(
stream(pos_id, result) self.results.get(pos_id),
results[pos_id] = result {
"short": short,
import pytest "status": NeotestResultStatus(report.outcome),
"errors": errors,
pytest.main(args=args, plugins=[NeotestResultCollector]) },
return results )
if not params:
self.stream(pos_id, result)
self.results[pos_id] = result