feat: streamed results

This commit is contained in:
Rónán Carrigan
2022-07-17 22:22:29 +01:00
parent 023b7bda64
commit aaf83100b6
4 changed files with 52 additions and 8 deletions

View File

@@ -92,12 +92,20 @@ end
function PythonNeotestAdapter.build_spec(args)
local position = args.tree:data()
local results_path = async.fn.tempname()
local stream_path = async.fn.tempname()
local x = io.open(stream_path, "w")
x:write("")
x:close()
local root = PythonNeotestAdapter.root(position.path)
local python = base.get_python_command(root)
local runner = get_runner(python)
local stream_data, stop_stream = lib.files.stream_lines(stream_path)
local script_args = vim.tbl_flatten({
"--results-file",
results_path,
"--stream-file",
stream_path,
"--runner",
runner,
"--",
@@ -112,11 +120,24 @@ function PythonNeotestAdapter.build_spec(args)
script_args,
})
local strategy_config = get_strategy_config(args.strategy, python, python_script, script_args)
---@type neotest.RunSpec
return {
command = command,
context = {
results_path = results_path,
stop_stream = stop_stream,
},
stream = function()
return function()
local lines = stream_data()
local results = {}
for _, line in ipairs(lines) do
local result = vim.json.decode(line, { luanil = { object = true } })
results[result.id] = result.result
end
return results
end
end,
strategy = strategy_config,
}
end
@@ -126,6 +147,7 @@ end
---@param result neotest.StrategyResult
---@return neotest.Result[]
function PythonNeotestAdapter.results(spec, result)
spec.context.stop_stream()
local success, data = pcall(lib.files.read, spec.context.results_path)
if not success then
data = "{}"

View File

@@ -3,6 +3,8 @@ import json
from enum import Enum
from typing import List
from neotest_python.base import NeotestResult
class TestRunner(str, Enum):
PYTEST = "pytest"
@@ -29,12 +31,24 @@ parser.add_argument(
required=True,
help="File to store result JSON in",
)
parser.add_argument(
"--stream-file",
dest="stream_file",
required=True,
help="File to stream result JSON to",
)
parser.add_argument("args", nargs="*")
def main(argv: List[str]):
args = parser.parse_args(argv)
adapter = get_adapter(TestRunner(args.runner))
results = adapter.run(args.args)
with open(args.stream_file, "w") as stream_file:
def stream(pos_id: str, result: NeotestResult):
stream_file.write(json.dumps({"id": pos_id, "result": result}) + "\n")
stream_file.flush()
results = adapter.run(args.args, stream)
with open(args.results_file, "w") as results_file:
json.dump(results, results_file)

View File

@@ -1,6 +1,6 @@
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
from .base import NeotestAdapter, NeotestError, NeotestResult, NeotestResultStatus
@@ -32,7 +32,9 @@ class PytestNeotestAdapter(NeotestAdapter):
buffer.seek(0)
return buffer.read()
def run(self, args: List[str]) -> Dict[str, NeotestResult]:
def run(
self, args: List[str], stream: Callable[[str, NeotestResult], None]
) -> Dict[str, NeotestResult]:
results: Dict[str, NeotestResult] = {}
pytest_config: "Config"
from _pytest._code.code import ExceptionChainRepr
@@ -52,7 +54,7 @@ class PytestNeotestAdapter(NeotestAdapter):
file_path, *name_path = report.nodeid.split("::")
abs_path = str(Path(pytest_config.rootpath, file_path))
test_name, *namespaces = reversed(name_path)
valid_test_name, *_ = test_name.split("[") # ]
valid_test_name, *params = test_name.split("[") # ]
errors: List[NeotestError] = []
short = self.get_short_output(pytest_config, report)
@@ -75,9 +77,11 @@ class PytestNeotestAdapter(NeotestAdapter):
errors.append({"message": error_message, "line": error_line})
else:
# TODO: Figure out how these are returned and how to represent
raise Exception("Unhandled error type, please report to neotest-python repo")
raise Exception(
"Unhandled error type, please report to neotest-python repo"
)
pos_id = "::".join([abs_path, *namespaces, valid_test_name])
results[pos_id] = self.update_result(
result = self.update_result(
results.get(pos_id),
{
"short": short,
@@ -85,6 +89,9 @@ class PytestNeotestAdapter(NeotestAdapter):
"errors": errors,
},
)
if not params:
stream(pos_id, result)
results[pos_id] = result
import pytest

View File

@@ -5,7 +5,7 @@ import traceback
import unittest
from pathlib import Path
from types import TracebackType
from typing import Any, Dict, Iterator, List, Tuple
from typing import Any, Dict, List, Tuple
from unittest import TestCase, TestResult, TestSuite
from unittest.runner import TextTestResult, TextTestRunner
@@ -42,7 +42,8 @@ class UnittestNeotestAdapter(NeotestAdapter):
relative_dotted = relative_stem.replace(os.sep, ".")
return [".".join([relative_dotted, *child_ids])]
def run(self, args: List[str]) -> Dict:
# TODO: Stream results
def run(self, args: List[str], _) -> Dict:
results = {}
errs: Dict[str, Tuple[Exception, Any, TracebackType]] = {}