From 48bf141103b94c9384e5542cd185b291909ac305 Mon Sep 17 00:00:00 2001 From: Adrian Frischkorn <5385601+afrischk@users.noreply.github.com> Date: Sun, 10 Dec 2023 20:15:08 +0100 Subject: [PATCH] feat: django support. (#54) --- lua/neotest-python/base.lua | 2 +- lua/neotest-python/init.lua | 4 +- neotest_python/__init__.py | 5 ++ neotest_python/django_unittest.py | 132 ++++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 3 deletions(-) create mode 100644 neotest_python/django_unittest.py diff --git a/lua/neotest-python/base.lua b/lua/neotest-python/base.lua index 90f703e..c529047 100644 --- a/lua/neotest-python/base.lua +++ b/lua/neotest-python/base.lua @@ -17,7 +17,7 @@ M.module_exists = function(module, python_command) return lib.process.run(vim.tbl_flatten({ python_command, "-c", - "import imp; imp.find_module('" .. module .. "')", + "import " .. module, })) == 0 end diff --git a/lua/neotest-python/init.lua b/lua/neotest-python/init.lua index 7e253db..f5da1db 100644 --- a/lua/neotest-python/init.lua +++ b/lua/neotest-python/init.lua @@ -59,10 +59,10 @@ local get_runner = function(python_command) if vim_test_runner == "pyunit" then return "unittest" end - if vim_test_runner and lib.func_util.index({ "unittest", "pytest" }, vim_test_runner) then + if vim_test_runner and lib.func_util.index({ "unittest", "pytest", "django" }, vim_test_runner) then return vim_test_runner end - local runner = base.module_exists("pytest", python_command) and "pytest" or "unittest" + local runner = base.module_exists("pytest", python_command) and "pytest" or base.module_exists("django", python_command) and "django" or "unittest" stored_runners[command_str] = runner return runner end diff --git a/neotest_python/__init__.py b/neotest_python/__init__.py index e8c9871..e898d36 100644 --- a/neotest_python/__init__.py +++ b/neotest_python/__init__.py @@ -9,6 +9,7 @@ from neotest_python.base import NeotestAdapter, NeotestResult class TestRunner(str, Enum): PYTEST = "pytest" UNITTEST = "unittest" + DJANGO = "django" def get_adapter(runner: TestRunner) -> NeotestAdapter: @@ -20,6 +21,10 @@ def get_adapter(runner: TestRunner) -> NeotestAdapter: from .unittest import UnittestNeotestAdapter return UnittestNeotestAdapter() + elif runner == TestRunner.DJANGO: + from .django_unittest import DjangoNeotestAdapter + + return DjangoNeotestAdapter() raise NotImplementedError(runner) diff --git a/neotest_python/django_unittest.py b/neotest_python/django_unittest.py new file mode 100644 index 0000000..1a8e7bb --- /dev/null +++ b/neotest_python/django_unittest.py @@ -0,0 +1,132 @@ +import inspect +import subprocess +import os +import sys +import traceback +import unittest +from argparse import ArgumentParser +from pathlib import Path +from types import TracebackType +from typing import Any, Tuple, Dict, List +from unittest import TestCase +from unittest.runner import TextTestResult +from django import setup as django_setup +from django.test.runner import DiscoverRunner +from .base import NeotestAdapter, NeotestError, NeotestResultStatus + + +class CaseUtilsMixin: + def case_file(self, case) -> str: + return str(Path(inspect.getmodule(case).__file__).absolute()) + + def case_id_elems(self, case) -> List[str]: + file = self.case_file(case) + elems = [file, case.__class__.__name__] + if isinstance(case, TestCase): + elems.append(case._testMethodName) + return elems + + def case_id(self, case: "TestCase | TestSuite") -> str: + return "::".join(self.case_id_elems(case)) + + +class DjangoNeotestAdapter(CaseUtilsMixin, NeotestAdapter): + def convert_args(self, case_id: str, args: List[str]) -> List[str]: + """Converts a neotest ID into test specifier for unittest""" + path, *child_ids = case_id.split("::") + if not child_ids: + child_ids = [] + relative_file = os.path.relpath(path, os.getcwd()) + relative_stem = os.path.splitext(relative_file)[0] + relative_dotted = relative_stem.replace(os.sep, ".") + return [*args, ".".join([relative_dotted, *child_ids])] + + def run(self, args: List[str], _) -> Dict: + errs: Dict[str, Tuple[Exception, Any, TracebackType]] = {} + results = {} + + class NeotestTextTestResult(CaseUtilsMixin, TextTestResult): + def addFailure(_, test: TestCase, err) -> None: + errs[self.case_id(test)] = err + return super().addFailure(test, err) + + def addError(_, test: TestCase, err) -> None: + errs[self.case_id(test)] = err + return super().addError(test, err) + + def addSuccess(_, test: TestCase) -> None: + results[self.case_id(test)] = { + "status": NeotestResultStatus.PASSED, + } + + class DjangoUnittestRunner(CaseUtilsMixin, DiscoverRunner): + def __init__(self, **kwargs): + django_setup() + DiscoverRunner.__init__(self, **kwargs) + + @classmethod + def add_arguments(cls, parser): + DiscoverRunner.add_arguments(parser) + parser.add_argument( + "--verbosity", + nargs="?", + default=2 + ) + parser.add_argument( + "--failfast", + action="store_true", + ) + + # override + def get_resultclass(self): + return NeotestTextTestResult + + def collect_results(self, django_test_results, neotest_results): + for case, message in ( + django_test_results.failures + django_test_results.errors + ): + case_id = self.case_id(case) + error_line = None + case_file = self.case_file(case) + if case_id in errs: + trace = errs[case_id][2] + summary = traceback.extract_tb(trace) + for frame in reversed(summary): + if frame.filename == case_file: + error_line = frame.lineno - 1 + break + neotest_results[case_id] = { + "status": NeotestResultStatus.FAILED, + "errors": [{"message": message, "line": error_line}], + "short": None, + } + for case, message in django_test_results.skipped: + neotest_results[self.case_id(case)] = { + "short": None, + "status": NeotestResultStatus.SKIPPED, + "errors": None, + } + + # override + def suite_result(self, suite, suite_results, **kwargs): + """Collect Django test suite results and convert them to Neotest compatible results.""" + self.collect_results(suite_results, results) + return ( + len(suite_results.failures) + + len(suite_results.errors) + + len(suite_results.unexpectedSuccesses) + ) + + # Make sure we can import relative to current path + sys.path.insert(0, os.getcwd()) + # Prepend an executable name which is just used in output + argv = ["neotest-python"] + self.convert_args(args[-1], args[:-1]) + # parse args + parser = ArgumentParser() + DjangoUnittestRunner.add_arguments(parser) + # run tests + runner = DjangoUnittestRunner( + **vars(parser.parse_args(argv[1:-1])) # parse plugin config args + ) + runner.run_tests(test_labels=[argv[-1]]) # pass test label + return results