diff --git a/neotest_python/pytest.py b/neotest_python/pytest.py index 979200c..29f9429 100644 --- a/neotest_python/pytest.py +++ b/neotest_python/pytest.py @@ -1,6 +1,6 @@ from io import StringIO from pathlib import Path -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union from .base import NeotestAdapter, NeotestError, NeotestResult, NeotestResultStatus @@ -18,7 +18,10 @@ class PytestNeotestAdapter(NeotestAdapter): import pytest result_collector = NeotestResultCollector(self, stream=stream) - pytest.main(args=args, plugins=[result_collector]) + pytest.main(args=args, plugins=[ + result_collector, + NeotestDebugpyPlugin(), + ]) return result_collector.results @@ -130,3 +133,43 @@ class NeotestResultCollector: if not params: self.stream(pos_id, result) self.results[pos_id] = result + + +class NeotestDebugpyPlugin: + """A pytest plugin that would make debugpy stop at thrown exceptions.""" + + def pytest_exception_interact( + self, + node: Union['pytest.Item', 'pytest.Collector'], + call: 'pytest.CallInfo', + report: Union['pytest.CollectReport', 'pytest.TestReport'], + ): + # call.excinfo: _pytest._code.ExceptionInfo + self.maybe_debugpy_postmortem(call.excinfo._excinfo) + + @staticmethod + def maybe_debugpy_postmortem(excinfo): + """Make the debugpy debugger enter and stop at a raised exception. + + excinfo: A (type(e), e, e.__traceback__) tuple. See sys.exc_info() + """ + # Reference: https://github.com/microsoft/debugpy/issues/723 + import threading + try: + import pydevd + except ImportError: + return # debugpy or pydevd not available, do nothing + + py_db = pydevd.get_global_debugger() + if py_db is None: + # Do nothing if not running with a DAP debugger, + # e.g. neotest was invoked with {strategy = dap} + return + + thread = threading.current_thread() + additional_info = py_db.set_additional_thread_info(thread) + additional_info.is_tracing += 1 + try: + py_db.stop_on_unhandled_exception(py_db, thread, additional_info, excinfo) + finally: + additional_info.is_tracing -= 1