Skip to content

Commit

Permalink
Type+PydanticModel more tool test code.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmchilton committed Jul 15, 2024
1 parent 600bcd1 commit 2472b44
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 35 deletions.
27 changes: 23 additions & 4 deletions lib/galaxy/tool_util/unittest_utils/interactor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
from typing import (
Any,
Dict,
List,
Optional,
)

from galaxy.tool_util.verify.interactor import (
ToolTestCase,
ToolTestCaseList,
)

NEW_HISTORY = object()
NEW_HISTORY_ID = "new"
EXISTING_HISTORY = {"id": "existing"}
Expand Down Expand Up @@ -41,22 +53,29 @@ def get_tests_summary(self):
},
}

def get_tool_tests(self, tool_id, tool_version=None):
def get_tool_tests_model(self, tool_id, tool_version=None) -> ToolTestCaseList:
tool_dict = self.get_tests_summary().get(tool_id)
test_defs = []
for this_tool_version, version_defs in tool_dict.items():
if tool_version is not None and tool_version != "*" and this_tool_version != tool_version:
continue

count = version_defs["count"]
for _ in range(count):
for index in range(count):
test_def = {
"tool_id": tool_id,
"tool_version": this_tool_version or "0.1.1-default",
"name": tool_id,
"test_index": index,
"inputs": [],
"outputs": [],
}
test_defs.append(test_def)
test_defs.append(ToolTestCase(**test_def))

if tool_version is None or tool_version != "*":
break

return test_defs
return ToolTestCaseList(__root__=test_defs)

def get_tool_tests(self, tool_id: str, tool_version: Optional[str] = None) -> List[Dict[str, Any]]:
return [m.dict() for m in self.get_tool_tests_model(tool_id, tool_version).__root__]
103 changes: 77 additions & 26 deletions lib/galaxy/tool_util/verify/interactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)

from packaging.version import Version
from pydantic import BaseModel
from requests import Response
from requests.cookies import RequestsCookieJar
from typing_extensions import (
Expand Down Expand Up @@ -225,12 +226,15 @@ def get_tests_summary(self):
assert response.status_code == 200, f"Non 200 response from tool tests available API. [{response.content}]"
return response.json()

def get_tool_tests(self, tool_id: str, tool_version: Optional[str] = None) -> ToolTestDictsT:
def get_tool_tests_model(self, tool_id: str, tool_version: Optional[str] = None) -> "ToolTestCaseList":
url = f"tools/{tool_id}/test_data"
params = {"tool_version": tool_version} if tool_version else None
response = self._get(url, data=params)
assert response.status_code == 200, f"Non 200 response from tool test API. [{response.content}]"
return response.json()
return ToolTestCaseList(__root__=[ToolTestCase(**t) for t in response.json()])

def get_tool_tests(self, tool_id: str, tool_version: Optional[str] = None) -> ToolTestDictsT:
return [test_case_to_dict(m) for m in self.get_tool_tests_model(tool_id, tool_version).root]

def verify_output_collection(
self, output_collection_def, output_collection_id, history, tool_id, tool_version=None
Expand Down Expand Up @@ -1647,6 +1651,43 @@ class ToolTestDescriptionDict(TypedDict):
exception: Optional[str]


class Assertion(BaseModel):
tag: str
attributes: Dict[str, Any]
children: Optional[List[Dict[str, Any]]]


AssertionModelList = Optional[List[Assertion]]


class ToolTestCase(BaseModel):
inputs: Any
outputs: Any
output_collections: List[Dict[str, Any]] = []
stdout: AssertionModelList = []
stderr: AssertionModelList = []
expect_exit_code: Optional[int] = None
expect_failure: bool = False
expect_test_failure: bool = False
maxseconds: Optional[int] = None
num_outputs: Optional[int] = None
command_line: AssertionModelList = []
command_version: AssertionModelList = []
required_files: List[Any] = []
required_data_tables: List[Any] = []
required_loc_files: List[str] = []
error: bool = False
exception: Optional[str] = None
name: str
tool_id: str
tool_version: str
test_index: int


class ToolTestCaseList(BaseModel):
__root__: List[ToolTestCase]


class ToolTestDescription:
"""
Encapsulates information about a tool test, and allows creation of a
Expand Down Expand Up @@ -1683,7 +1724,6 @@ def __init__(self, processed_test_dict: ToolTestDict):
processed_test_dict = cast(InvalidToolTestDict, processed_test_dict)
maxseconds = DEFAULT_TOOL_TEST_WAIT
output_collections = []

self.test_index = test_index
assert (
"tool_id" in processed_test_dict
Expand Down Expand Up @@ -1726,36 +1766,47 @@ def test_data(self):
"""
return test_data_iter(self.required_files)

def to_dict(self) -> ToolTestDescriptionDict:
def to_model(self) -> ToolTestCase:
inputs_dict = {}
for key, value in self.inputs.items():
if hasattr(value, "to_dict"):
inputs_dict[key] = value.to_dict()
else:
inputs_dict[key] = value

return {
"inputs": inputs_dict,
"outputs": self.outputs,
"output_collections": [_.to_dict() for _ in self.output_collections],
"num_outputs": self.num_outputs,
"command_line": self.command_line,
"command_version": self.command_version,
"stdout": self.stdout,
"stderr": self.stderr,
"expect_exit_code": self.expect_exit_code,
"expect_failure": self.expect_failure,
"expect_test_failure": self.expect_test_failure,
"name": self.name,
"test_index": self.test_index,
"tool_id": self.tool_id,
"tool_version": self.tool_version,
"required_files": self.required_files,
"required_data_tables": self.required_data_tables,
"required_loc_files": self.required_loc_files,
"error": self.error,
"exception": self.exception,
}
return ToolTestCase(
**{
"inputs": inputs_dict,
"outputs": self.outputs,
"output_collections": [_.to_dict() for _ in self.output_collections],
"num_outputs": self.num_outputs,
"command_line": self.command_line,
"command_version": self.command_version,
"stdout": self.stdout,
"stderr": self.stderr,
"expect_exit_code": self.expect_exit_code,
"expect_failure": self.expect_failure,
"expect_test_failure": self.expect_test_failure,
"name": self.name,
"test_index": self.test_index,
"tool_id": self.tool_id,
"tool_version": self.tool_version,
"required_files": self.required_files,
"required_data_tables": self.required_data_tables,
"required_loc_files": self.required_loc_files,
"error": self.error,
"exception": self.exception,
}
)

def to_dict(self) -> ToolTestDescriptionDict:
# For backward compatibility maintain a dict version - if
# this comment got merged the converter tests failed without this.
return test_case_to_dict(self.to_model())


def test_case_to_dict(model: ToolTestCase) -> ToolTestDict:
return cast(ToolTestDict, model.dict())


def test_data_iter(required_files):
Expand Down
2 changes: 1 addition & 1 deletion lib/galaxy/tool_util/verify/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def build_case_references(
assert tool_id
tool_test_dicts: ToolTestDictsT = galaxy_interactor.get_tool_tests(tool_id, tool_version=tool_version)
for i, tool_test_dict in enumerate(tool_test_dicts):
this_tool_version = tool_test_dict.get("tool_version", tool_version)
this_tool_version = tool_test_dict.get("tool_version") or tool_version
this_test_index = i
if test_index == ALL_TESTS or i == test_index:
test_reference = TestReference(tool_id, this_tool_version, this_test_index)
Expand Down
14 changes: 10 additions & 4 deletions lib/galaxy/webapps/galaxy/api/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
FetchDataFormPayload,
FetchDataPayload,
)
from galaxy.tool_util.verify.interactor import (
ToolTestCase,
ToolTestCaseList,
)
from galaxy.tools import Tool
from galaxy.tools.evaluation import global_tool_errors
from galaxy.util.zipstream import ZipstreamWrapper
from galaxy.web import (
Expand Down Expand Up @@ -316,7 +321,7 @@ def tests_summary(self, trans: GalaxyWebTransaction, **kwd):
return test_counts_by_tool

@expose_api_anonymous_and_sessionless
def test_data(self, trans: GalaxyWebTransaction, id, **kwd):
def test_data(self, trans: GalaxyWebTransaction, id, **kwd) -> ToolTestCaseList:
"""
GET /api/tools/{tool_id}/test_data?tool_version={tool_version}
Expand All @@ -331,6 +336,7 @@ def test_data(self, trans: GalaxyWebTransaction, id, **kwd):
"""
kwd = _kwd_or_payload(kwd)
tool_version = kwd.get("tool_version", None)
tools: List[Tool]
if tool_version == "*":
tools = self.app.toolbox.get_tool(id, get_all_versions=True)
for tool in tools:
Expand All @@ -339,10 +345,10 @@ def test_data(self, trans: GalaxyWebTransaction, id, **kwd):
else:
tools = [self.service._get_tool(trans, id, tool_version=tool_version, user=trans.user)]

test_defs = []
test_defs: List[ToolTestCase] = []
for tool in tools:
test_defs.extend([t.to_dict() for t in tool.tests])
return test_defs
test_defs.extend([t.to_model() for t in tool.tests])
return ToolTestCaseList(test_defs)

@web.require_admin
@expose_api
Expand Down

0 comments on commit 2472b44

Please sign in to comment.