From 601d3a9dd1267ab8951a886b621342880fc14417 Mon Sep 17 00:00:00 2001 From: ovsds Date: Wed, 18 Dec 2024 17:39:10 +0100 Subject: [PATCH] fix: PR --- .../result/complex_queries/test_lookup_functions.py | 6 ++++-- lib/dl_formula/dl_formula/inspect/expression.py | 11 ----------- .../multi_query/splitters/query_fork.py | 12 ++++-------- 3 files changed, 8 insertions(+), 21 deletions(-) diff --git a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py index 906e0e702..2b842848c 100644 --- a/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py +++ b/lib/dl_api_lib/dl_api_lib_tests/db/data_api/result/complex_queries/test_lookup_functions.py @@ -86,9 +86,11 @@ def test_ago_filtered(self, control_api, data_api, saved_dataset): assert result_resp.status_code == HTTPStatus.OK, result_resp.json query: str = result_resp.json["blocks"][0]["query"] - expected_query_pattern = r"JOIN[\S\s]*\([\S\s]*order_date[\S\s]*5[\S\s]*>=[\S\s]*2014-01-06[\S\s]*\)[\S\s]*ON" + expected_query_pattern = r"JOIN.*\(.*order_date.*5.*>=.*2014-01-06.*\).*ON" assert re.search( - expected_query_pattern, query + expected_query_pattern, + query, + flags=re.DOTALL, ), "Expected to find pattern 'JOIN (... order_date >= 2014-01-06 ...) ON' in query" def test_ago_variants(self, control_api, data_api, saved_dataset): diff --git a/lib/dl_formula/dl_formula/inspect/expression.py b/lib/dl_formula/dl_formula/inspect/expression.py index c07ab6537..9433890f4 100644 --- a/lib/dl_formula/dl_formula/inspect/expression.py +++ b/lib/dl_formula/dl_formula/inspect/expression.py @@ -452,17 +452,6 @@ def contains_lookup_functions(node: nodes.FormulaItem) -> bool: return False -def contains_node(node: nodes.FormulaItem, target_node: nodes.FormulaItem) -> bool: - if node == target_node: - return True - - for child in autonomous_children(node): - if contains_node(child, target_node): - return True - - return False - - def resolve_dimensions( node_stack: Iterable[nodes.FormulaItem], dimensions: List[nodes.FormulaItem], # TODO: rename to global_dimensions diff --git a/lib/dl_query_processing/dl_query_processing/multi_query/splitters/query_fork.py b/lib/dl_query_processing/dl_query_processing/multi_query/splitters/query_fork.py index 06c013bbb..04cd22050 100644 --- a/lib/dl_query_processing/dl_query_processing/multi_query/splitters/query_fork.py +++ b/lib/dl_query_processing/dl_query_processing/multi_query/splitters/query_fork.py @@ -11,7 +11,6 @@ import dl_formula.core.nodes as formula_nodes from dl_formula.inspect.env import InspectionEnvironment import dl_formula.inspect.expression as inspect_expression -from dl_formula.inspect.expression import contains_node import dl_formula.inspect.node as inspect_node from dl_formula.mutation.mutation import ( FormulaMutation, @@ -462,13 +461,10 @@ def get_split_masks( for filter_idx, filter_formula in enumerate(query.filters): if filter_formula.original_field_id in qfork_info.bfb_field_ids: # Filter field is in BFB, so exclude it unless it is mutated by one of the BFB mutations - if any( - contains_node(filter_formula.formula_obj, mutation.original) - for mutation in qfork_info.bfb_filter_mutations - ): - new_filter = filter_formula.clone( - formula_obj=apply_mutations(filter_formula.formula_obj, qfork_info.bfb_filter_mutations), - ) + new_formula_obj = apply_mutations(filter_formula.formula_obj, qfork_info.bfb_filter_mutations) + + if new_formula_obj is not filter_formula.formula_obj: + new_filter = filter_formula.clone(formula_obj=new_formula_obj) add_filters.append(new_filter) continue if filter_idx in split_filter_indices: