Skip to content

Commit

Permalink
Refactor QueryValidator
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocapozzoli committed Oct 17, 2024
1 parent c3efc39 commit 1c20eed
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 18 deletions.
39 changes: 37 additions & 2 deletions das-query-engine/tests/integration/handle/test_query_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
BaseTestHandlerAction,
expression,
inheritance,
similarity,
human,
mammal,
symbol,
)
Expand Down Expand Up @@ -33,12 +35,45 @@ def valid_event(self, action_type):
},
}
}

@pytest.fixture
def query_list(self, action_type):
return {
"body": {
"action": action_type,
"input": {
"query": [
{
"atom_type": "link",
"type": expression,
"targets": [
{"atom_type": "node", "type": symbol, "name": inheritance},
{"atom_type": "variable", "name": "$v1"},
{"atom_type": "node", "type": symbol, "name": mammal},
],
},
{
"atom_type": "link",
"type": expression,
"targets": [
{"atom_type": "node", "type": symbol, "name": similarity},
{"atom_type": "variable", "name": "$v1"},
{"atom_type": "node", "type": symbol, "name": human},
],
}
]
},
}
}

@pytest.mark.parametrize("query_input", ["valid_event", "query_list"])
def test_query_action(
self,
valid_event,
request,
query_input,
):
body, status_code = self.make_request(valid_event)
query_data = request.getfixturevalue(query_input)
body, status_code = self.make_request(query_data)
expected_status_code = 200

assert (
Expand Down
35 changes: 19 additions & 16 deletions das-query-engine/validators/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,28 +53,31 @@ class QueryValidator(PayloadValidator):
strict = True

@staticmethod
def validate_query(query, *args, **kwargs) -> bool:
if not isinstance(query, dict):
def validate_query(queries, *args, **kwargs) -> bool:
if isinstance(queries, dict):
queries = [queries]
elif not isinstance(queries, list) or not all(isinstance(query, dict) for query in queries):
return False

for query in queries:
atom_type = query.get("atom_type")
query_type = query.get("type")
name = query.get("name")
targets = query.get("targets")

atom_type = query.get("atom_type")
query_type = query.get("type")
name = query.get("name")
targets = query.get("targets")
if atom_type not in ["node", "link"]:
return False

if atom_type not in ["node", "link"]:
return False
if not isinstance(query_type, str):
return False

if not isinstance(query_type, str):
return False
if atom_type == "node" and not isinstance(name, str):
return False

if atom_type == "node" and not isinstance(name, str):
return False

if atom_type == "link" and not isinstance(targets, list):
return False
if atom_type == "link" and not isinstance(targets, list):
return False

return True
return True

query = datatypes.Function(validate_query, required=True)

Expand Down

0 comments on commit 1c20eed

Please sign in to comment.