refactor(pytest): PytestNeotestAdapter (#24)
Having `NeotestResultCollector` (a pytest plugin) as a inner local class would make the code a bit difficult to read due to quite much indentation. This commit does refactoring on NeotestResultCollector to make it a module-level class with a reference to NeotestAdapter. This refactoring would make easier adding more pytest plugins (e.g., debugger integration) in the future. There should be no changes in behaviors.
This commit is contained in:
@@ -3,7 +3,7 @@ import json
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from neotest_python.base import NeotestResult
|
from neotest_python.base import NeotestAdapter, NeotestResult
|
||||||
|
|
||||||
|
|
||||||
class TestRunner(str, Enum):
|
class TestRunner(str, Enum):
|
||||||
@@ -11,7 +11,7 @@ class TestRunner(str, Enum):
|
|||||||
UNITTEST = "unittest"
|
UNITTEST = "unittest"
|
||||||
|
|
||||||
|
|
||||||
def get_adapter(runner: TestRunner):
|
def get_adapter(runner: TestRunner) -> NeotestAdapter:
|
||||||
if runner == TestRunner.PYTEST:
|
if runner == TestRunner.PYTEST:
|
||||||
from .pytest import PytestNeotestAdapter
|
from .pytest import PytestNeotestAdapter
|
||||||
|
|
||||||
@@ -43,6 +43,7 @@ parser.add_argument("args", nargs="*")
|
|||||||
def main(argv: List[str]):
|
def main(argv: List[str]):
|
||||||
args = parser.parse_args(argv)
|
args = parser.parse_args(argv)
|
||||||
adapter = get_adapter(TestRunner(args.runner))
|
adapter = get_adapter(TestRunner(args.runner))
|
||||||
|
|
||||||
with open(args.stream_file, "w") as stream_file:
|
with open(args.stream_file, "w") as stream_file:
|
||||||
|
|
||||||
def stream(pos_id: str, result: NeotestResult):
|
def stream(pos_id: str, result: NeotestResult):
|
||||||
@@ -50,5 +51,6 @@ def main(argv: List[str]):
|
|||||||
stream_file.flush()
|
stream_file.flush()
|
||||||
|
|
||||||
results = adapter.run(args.args, stream)
|
results = adapter.run(args.args, stream)
|
||||||
|
|
||||||
with open(args.results_file, "w") as results_file:
|
with open(args.results_file, "w") as results_file:
|
||||||
json.dump(results, results_file)
|
json.dump(results, results_file)
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
|
import abc
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
class NeotestResultStatus(str, Enum):
|
class NeotestResultStatus(str, Enum):
|
||||||
@@ -29,7 +30,8 @@ else:
|
|||||||
NeotestResult = Dict
|
NeotestResult = Dict
|
||||||
|
|
||||||
|
|
||||||
class NeotestAdapter:
|
class NeotestAdapter(abc.ABC):
|
||||||
|
|
||||||
def update_result(
|
def update_result(
|
||||||
self, base: Optional[NeotestResult], update: NeotestResult
|
self, base: Optional[NeotestResult], update: NeotestResult
|
||||||
) -> NeotestResult:
|
) -> NeotestResult:
|
||||||
@@ -40,3 +42,8 @@ class NeotestAdapter:
|
|||||||
"errors": (base.get("errors") or []) + (update.get("errors") or []) or None,
|
"errors": (base.get("errors") or []) + (update.get("errors") or []) or None,
|
||||||
"short": (base.get("short") or "") + (update.get("short") or ""),
|
"short": (base.get("short") or "") + (update.get("short") or ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def run(self, args: List[str], stream: Callable):
|
||||||
|
del args, stream
|
||||||
|
raise NotImplementedError
|
||||||
|
@@ -10,7 +10,33 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class PytestNeotestAdapter(NeotestAdapter):
|
class PytestNeotestAdapter(NeotestAdapter):
|
||||||
def get_short_output(self, config: "Config", report: "TestReport") -> Optional[str]:
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
args: List[str],
|
||||||
|
stream: Callable[[str, NeotestResult], None],
|
||||||
|
) -> Dict[str, NeotestResult]:
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
result_collector = NeotestResultCollector(self, stream=stream)
|
||||||
|
pytest.main(args=args, plugins=[result_collector])
|
||||||
|
return result_collector.results
|
||||||
|
|
||||||
|
|
||||||
|
class NeotestResultCollector:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
adapter: PytestNeotestAdapter,
|
||||||
|
stream: Callable[[str, NeotestResult], None],
|
||||||
|
):
|
||||||
|
self.stream = stream
|
||||||
|
self.adapter = adapter
|
||||||
|
|
||||||
|
self.pytest_config: "Config" = None # type: ignore
|
||||||
|
self.results: Dict[str, NeotestResult] = {}
|
||||||
|
|
||||||
|
def _get_short_output(self, config: "Config", report: "TestReport") -> Optional[str]:
|
||||||
from _pytest.terminal import TerminalReporter
|
from _pytest.terminal import TerminalReporter
|
||||||
|
|
||||||
buffer = StringIO()
|
buffer = StringIO()
|
||||||
@@ -32,88 +58,75 @@ class PytestNeotestAdapter(NeotestAdapter):
|
|||||||
buffer.seek(0)
|
buffer.seek(0)
|
||||||
return buffer.read()
|
return buffer.read()
|
||||||
|
|
||||||
def run(
|
def pytest_deselected(self, items: List["pytest.Item"]):
|
||||||
self, args: List[str], stream: Callable[[str, NeotestResult], None]
|
for report in items:
|
||||||
) -> Dict[str, NeotestResult]:
|
file_path, *name_path = report.nodeid.split("::")
|
||||||
results: Dict[str, NeotestResult] = {}
|
abs_path = str(Path(self.pytest_config.rootdir, file_path))
|
||||||
pytest_config: "Config"
|
test_name, *namespaces = reversed(name_path)
|
||||||
from _pytest._code.code import ExceptionChainRepr
|
valid_test_name, *params = test_name.split("[") # ]
|
||||||
|
pos_id = "::".join([abs_path, *namespaces, valid_test_name])
|
||||||
|
result = self.adapter.update_result(
|
||||||
|
self.results.get(pos_id),
|
||||||
|
{
|
||||||
|
"short": None,
|
||||||
|
"status": NeotestResultStatus.SKIPPED,
|
||||||
|
"errors": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if not params:
|
||||||
|
self.stream(pos_id, result)
|
||||||
|
self.results[pos_id] = result
|
||||||
|
|
||||||
class NeotestResultCollector:
|
def pytest_cmdline_main(self, config: "Config"):
|
||||||
@staticmethod
|
self.pytest_config = config
|
||||||
def pytest_deselected(items: List):
|
|
||||||
for report in items:
|
|
||||||
file_path, *name_path = report.nodeid.split("::")
|
|
||||||
abs_path = str(Path(pytest_config.rootdir, file_path))
|
|
||||||
test_name, *namespaces = reversed(name_path)
|
|
||||||
valid_test_name, *params = test_name.split("[") # ]
|
|
||||||
pos_id = "::".join([abs_path, *namespaces, valid_test_name])
|
|
||||||
result = self.update_result(
|
|
||||||
results.get(pos_id),
|
|
||||||
{
|
|
||||||
"short": None,
|
|
||||||
"status": NeotestResultStatus.SKIPPED,
|
|
||||||
"errors": [],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if not params:
|
|
||||||
stream(pos_id, result)
|
|
||||||
results[pos_id] = result
|
|
||||||
|
|
||||||
@staticmethod
|
def pytest_runtest_logreport(self, report: "TestReport"):
|
||||||
def pytest_cmdline_main(config: "Config"):
|
if report.when != "call" and not (
|
||||||
nonlocal pytest_config
|
report.outcome == "skipped" and report.when == "setup"
|
||||||
pytest_config = config
|
):
|
||||||
|
return
|
||||||
|
|
||||||
@staticmethod
|
file_path, *name_path = report.nodeid.split("::")
|
||||||
def pytest_runtest_logreport(report: "TestReport"):
|
abs_path = str(Path(self.pytest_config.rootdir, file_path))
|
||||||
if report.when != "call" and not (
|
test_name, *namespaces = reversed(name_path)
|
||||||
report.outcome == "skipped" and report.when == "setup"
|
valid_test_name, *params = test_name.split("[") # ]
|
||||||
):
|
pos_id = "::".join([abs_path, *namespaces, valid_test_name])
|
||||||
return
|
|
||||||
file_path, *name_path = report.nodeid.split("::")
|
|
||||||
abs_path = str(Path(pytest_config.rootdir, file_path))
|
|
||||||
test_name, *namespaces = reversed(name_path)
|
|
||||||
valid_test_name, *params = test_name.split("[") # ]
|
|
||||||
pos_id = "::".join([abs_path, *namespaces, valid_test_name])
|
|
||||||
|
|
||||||
errors: List[NeotestError] = []
|
errors: List[NeotestError] = []
|
||||||
short = self.get_short_output(pytest_config, report)
|
short = self._get_short_output(self.pytest_config, report)
|
||||||
if report.outcome == "failed":
|
|
||||||
exc_repr = report.longrepr
|
if report.outcome == "failed":
|
||||||
# Test fails due to condition outside of test e.g. xfail
|
from _pytest._code.code import ExceptionChainRepr
|
||||||
if isinstance(exc_repr, str):
|
|
||||||
errors.append({"message": exc_repr, "line": None})
|
exc_repr = report.longrepr
|
||||||
# Test failed internally
|
# Test fails due to condition outside of test e.g. xfail
|
||||||
elif isinstance(exc_repr, ExceptionChainRepr):
|
if isinstance(exc_repr, str):
|
||||||
reprtraceback = exc_repr.reprtraceback
|
errors.append({"message": exc_repr, "line": None})
|
||||||
error_message = exc_repr.reprcrash.message # type: ignore
|
# Test failed internally
|
||||||
error_line = None
|
elif isinstance(exc_repr, ExceptionChainRepr):
|
||||||
for repr in reversed(reprtraceback.reprentries):
|
reprtraceback = exc_repr.reprtraceback
|
||||||
if (
|
error_message = exc_repr.reprcrash.message # type: ignore
|
||||||
hasattr(repr, "reprfileloc")
|
error_line = None
|
||||||
and repr.reprfileloc.path == file_path
|
for repr in reversed(reprtraceback.reprentries):
|
||||||
):
|
if (
|
||||||
error_line = repr.reprfileloc.lineno - 1
|
hasattr(repr, "reprfileloc")
|
||||||
errors.append({"message": error_message, "line": error_line})
|
and repr.reprfileloc.path == file_path
|
||||||
else:
|
):
|
||||||
# TODO: Figure out how these are returned and how to represent
|
error_line = repr.reprfileloc.lineno - 1
|
||||||
raise Exception(
|
errors.append({"message": error_message, "line": error_line})
|
||||||
"Unhandled error type, please report to neotest-python repo"
|
else:
|
||||||
)
|
# TODO: Figure out how these are returned and how to represent
|
||||||
result = self.update_result(
|
raise Exception(
|
||||||
results.get(pos_id),
|
"Unhandled error type, please report to neotest-python repo"
|
||||||
{
|
|
||||||
"short": short,
|
|
||||||
"status": NeotestResultStatus(report.outcome),
|
|
||||||
"errors": errors,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
if not params:
|
result: NeotestResult = self.adapter.update_result(
|
||||||
stream(pos_id, result)
|
self.results.get(pos_id),
|
||||||
results[pos_id] = result
|
{
|
||||||
|
"short": short,
|
||||||
import pytest
|
"status": NeotestResultStatus(report.outcome),
|
||||||
|
"errors": errors,
|
||||||
pytest.main(args=args, plugins=[NeotestResultCollector])
|
},
|
||||||
return results
|
)
|
||||||
|
if not params:
|
||||||
|
self.stream(pos_id, result)
|
||||||
|
self.results[pos_id] = result
|
||||||
|
Reference in New Issue
Block a user