From 1c20eedbc34915e6947ca8142e47c33bf0161b52 Mon Sep 17 00:00:00 2001 From: marcocapozzoli Date: Thu, 17 Oct 2024 10:38:32 -0300 Subject: [PATCH] Refactor QueryValidator --- .../integration/handle/test_query_action.py | 39 ++++++++++++++++++- das-query-engine/validators/actions.py | 35 +++++++++-------- 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/das-query-engine/tests/integration/handle/test_query_action.py b/das-query-engine/tests/integration/handle/test_query_action.py index 528ab3d..0a09afe 100644 --- a/das-query-engine/tests/integration/handle/test_query_action.py +++ b/das-query-engine/tests/integration/handle/test_query_action.py @@ -5,6 +5,8 @@ BaseTestHandlerAction, expression, inheritance, + similarity, + human, mammal, symbol, ) @@ -33,12 +35,45 @@ def valid_event(self, action_type): }, } } + + @pytest.fixture + def query_list(self, action_type): + return { + "body": { + "action": action_type, + "input": { + "query": [ + { + "atom_type": "link", + "type": expression, + "targets": [ + {"atom_type": "node", "type": symbol, "name": inheritance}, + {"atom_type": "variable", "name": "$v1"}, + {"atom_type": "node", "type": symbol, "name": mammal}, + ], + }, + { + "atom_type": "link", + "type": expression, + "targets": [ + {"atom_type": "node", "type": symbol, "name": similarity}, + {"atom_type": "variable", "name": "$v1"}, + {"atom_type": "node", "type": symbol, "name": human}, + ], + } + ] + }, + } + } + @pytest.mark.parametrize("query_input", ["valid_event", "query_list"]) def test_query_action( self, - valid_event, + request, + query_input, ): - body, status_code = self.make_request(valid_event) + query_data = request.getfixturevalue(query_input) + body, status_code = self.make_request(query_data) expected_status_code = 200 assert ( diff --git a/das-query-engine/validators/actions.py b/das-query-engine/validators/actions.py index 78839d5..af15ac3 100644 --- a/das-query-engine/validators/actions.py +++ b/das-query-engine/validators/actions.py @@ -53,28 +53,31 @@ class QueryValidator(PayloadValidator): strict = True @staticmethod - def validate_query(query, *args, **kwargs) -> bool: - if not isinstance(query, dict): + def validate_query(queries, *args, **kwargs) -> bool: + if isinstance(queries, dict): + queries = [queries] + elif not isinstance(queries, list) or not all(isinstance(query, dict) for query in queries): return False + + for query in queries: + atom_type = query.get("atom_type") + query_type = query.get("type") + name = query.get("name") + targets = query.get("targets") - atom_type = query.get("atom_type") - query_type = query.get("type") - name = query.get("name") - targets = query.get("targets") + if atom_type not in ["node", "link"]: + return False - if atom_type not in ["node", "link"]: - return False + if not isinstance(query_type, str): + return False - if not isinstance(query_type, str): - return False + if atom_type == "node" and not isinstance(name, str): + return False - if atom_type == "node" and not isinstance(name, str): - return False - - if atom_type == "link" and not isinstance(targets, list): - return False + if atom_type == "link" and not isinstance(targets, list): + return False - return True + return True query = datatypes.Function(validate_query, required=True)