feat: streamed results

This commit is contained in:
Rónán Carrigan
2022-07-17 22:22:29 +01:00
parent 023b7bda64
commit aaf83100b6
4 changed files with 52 additions and 8 deletions

View File

@@ -92,12 +92,20 @@ end
function PythonNeotestAdapter.build_spec(args) function PythonNeotestAdapter.build_spec(args)
local position = args.tree:data() local position = args.tree:data()
local results_path = async.fn.tempname() local results_path = async.fn.tempname()
local stream_path = async.fn.tempname()
local x = io.open(stream_path, "w")
x:write("")
x:close()
local root = PythonNeotestAdapter.root(position.path) local root = PythonNeotestAdapter.root(position.path)
local python = base.get_python_command(root) local python = base.get_python_command(root)
local runner = get_runner(python) local runner = get_runner(python)
local stream_data, stop_stream = lib.files.stream_lines(stream_path)
local script_args = vim.tbl_flatten({ local script_args = vim.tbl_flatten({
"--results-file", "--results-file",
results_path, results_path,
"--stream-file",
stream_path,
"--runner", "--runner",
runner, runner,
"--", "--",
@@ -112,11 +120,24 @@ function PythonNeotestAdapter.build_spec(args)
script_args, script_args,
}) })
local strategy_config = get_strategy_config(args.strategy, python, python_script, script_args) local strategy_config = get_strategy_config(args.strategy, python, python_script, script_args)
---@type neotest.RunSpec
return { return {
command = command, command = command,
context = { context = {
results_path = results_path, results_path = results_path,
stop_stream = stop_stream,
}, },
stream = function()
return function()
local lines = stream_data()
local results = {}
for _, line in ipairs(lines) do
local result = vim.json.decode(line, { luanil = { object = true } })
results[result.id] = result.result
end
return results
end
end,
strategy = strategy_config, strategy = strategy_config,
} }
end end
@@ -126,6 +147,7 @@ end
---@param result neotest.StrategyResult ---@param result neotest.StrategyResult
---@return neotest.Result[] ---@return neotest.Result[]
function PythonNeotestAdapter.results(spec, result) function PythonNeotestAdapter.results(spec, result)
spec.context.stop_stream()
local success, data = pcall(lib.files.read, spec.context.results_path) local success, data = pcall(lib.files.read, spec.context.results_path)
if not success then if not success then
data = "{}" data = "{}"

View File

@@ -3,6 +3,8 @@ import json
from enum import Enum from enum import Enum
from typing import List from typing import List
from neotest_python.base import NeotestResult
class TestRunner(str, Enum): class TestRunner(str, Enum):
PYTEST = "pytest" PYTEST = "pytest"
@@ -29,12 +31,24 @@ parser.add_argument(
required=True, required=True,
help="File to store result JSON in", help="File to store result JSON in",
) )
parser.add_argument(
"--stream-file",
dest="stream_file",
required=True,
help="File to stream result JSON to",
)
parser.add_argument("args", nargs="*") parser.add_argument("args", nargs="*")
def main(argv: List[str]): def main(argv: List[str]):
args = parser.parse_args(argv) args = parser.parse_args(argv)
adapter = get_adapter(TestRunner(args.runner)) adapter = get_adapter(TestRunner(args.runner))
results = adapter.run(args.args) with open(args.stream_file, "w") as stream_file:
def stream(pos_id: str, result: NeotestResult):
stream_file.write(json.dumps({"id": pos_id, "result": result}) + "\n")
stream_file.flush()
results = adapter.run(args.args, stream)
with open(args.results_file, "w") as results_file: with open(args.results_file, "w") as results_file:
json.dump(results, results_file) json.dump(results, results_file)

View File

@@ -1,6 +1,6 @@
from io import StringIO from io import StringIO
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, cast from typing import TYPE_CHECKING, Callable, Dict, List, Optional
from .base import NeotestAdapter, NeotestError, NeotestResult, NeotestResultStatus from .base import NeotestAdapter, NeotestError, NeotestResult, NeotestResultStatus
@@ -32,7 +32,9 @@ class PytestNeotestAdapter(NeotestAdapter):
buffer.seek(0) buffer.seek(0)
return buffer.read() return buffer.read()
def run(self, args: List[str]) -> Dict[str, NeotestResult]: def run(
self, args: List[str], stream: Callable[[str, NeotestResult], None]
) -> Dict[str, NeotestResult]:
results: Dict[str, NeotestResult] = {} results: Dict[str, NeotestResult] = {}
pytest_config: "Config" pytest_config: "Config"
from _pytest._code.code import ExceptionChainRepr from _pytest._code.code import ExceptionChainRepr
@@ -52,7 +54,7 @@ class PytestNeotestAdapter(NeotestAdapter):
file_path, *name_path = report.nodeid.split("::") file_path, *name_path = report.nodeid.split("::")
abs_path = str(Path(pytest_config.rootpath, file_path)) abs_path = str(Path(pytest_config.rootpath, file_path))
test_name, *namespaces = reversed(name_path) test_name, *namespaces = reversed(name_path)
valid_test_name, *_ = test_name.split("[") # ] valid_test_name, *params = test_name.split("[") # ]
errors: List[NeotestError] = [] errors: List[NeotestError] = []
short = self.get_short_output(pytest_config, report) short = self.get_short_output(pytest_config, report)
@@ -75,9 +77,11 @@ class PytestNeotestAdapter(NeotestAdapter):
errors.append({"message": error_message, "line": error_line}) errors.append({"message": error_message, "line": error_line})
else: else:
# TODO: Figure out how these are returned and how to represent # TODO: Figure out how these are returned and how to represent
raise Exception("Unhandled error type, please report to neotest-python repo") raise Exception(
"Unhandled error type, please report to neotest-python repo"
)
pos_id = "::".join([abs_path, *namespaces, valid_test_name]) pos_id = "::".join([abs_path, *namespaces, valid_test_name])
results[pos_id] = self.update_result( result = self.update_result(
results.get(pos_id), results.get(pos_id),
{ {
"short": short, "short": short,
@@ -85,6 +89,9 @@ class PytestNeotestAdapter(NeotestAdapter):
"errors": errors, "errors": errors,
}, },
) )
if not params:
stream(pos_id, result)
results[pos_id] = result
import pytest import pytest

View File

@@ -5,7 +5,7 @@ import traceback
import unittest import unittest
from pathlib import Path from pathlib import Path
from types import TracebackType from types import TracebackType
from typing import Any, Dict, Iterator, List, Tuple from typing import Any, Dict, List, Tuple
from unittest import TestCase, TestResult, TestSuite from unittest import TestCase, TestResult, TestSuite
from unittest.runner import TextTestResult, TextTestRunner from unittest.runner import TextTestResult, TextTestRunner
@@ -42,7 +42,8 @@ class UnittestNeotestAdapter(NeotestAdapter):
relative_dotted = relative_stem.replace(os.sep, ".") relative_dotted = relative_stem.replace(os.sep, ".")
return [".".join([relative_dotted, *child_ids])] return [".".join([relative_dotted, *child_ids])]
def run(self, args: List[str]) -> Dict: # TODO: Stream results
def run(self, args: List[str], _) -> Dict:
results = {} results = {}
errs: Dict[str, Tuple[Exception, Any, TracebackType]] = {} errs: Dict[str, Tuple[Exception, Any, TracebackType]] = {}