From 120331f85178f2a061703cbe9ec0538de91cb1e3 Mon Sep 17 00:00:00 2001
From: Suzy Wang
Date: Wed, 10 Apr 2024 18:43:01 -0700
Subject: [PATCH] pushing down primary key predicates to proj optimization indexhint() condition

---
 src/Interpreters/InterpreterSelectQuery.cpp   | 83 +++++++++++++++++--
 src/Interpreters/InterpreterSelectQuery.h     |  5 +-
 .../02923_projection_query_optimize.reference | 26 +++++-
 .../02923_projection_query_optimize.sql       |  5 +-
 4 files changed, 107 insertions(+), 12 deletions(-)

diff --git a/src/Interpreters/InterpreterSelectQuery.cpp b/src/Interpreters/InterpreterSelectQuery.cpp
index 89d16dd0b807..2be3a7448db0 100644
--- a/src/Interpreters/InterpreterSelectQuery.cpp
+++ b/src/Interpreters/InterpreterSelectQuery.cpp
@@ -2142,12 +2142,14 @@ ASTPtr InterpreterSelectQuery::pkOptimization(
         }
     }
 
-    return analyze_where_ast(where_ast, proj_pks, primary_keys);
+    ASTs primary_key_predicates;
+    return analyze_where_ast(where_ast, proj_pks, primary_key_predicates, primary_keys);
 }
 
 ASTPtr InterpreterSelectQuery::analyze_where_ast(
     const ASTPtr & ast,
     NameSet & proj_pks,
+    ASTs primary_key_predicates, /// By value: predicates collected inside one subtree must not leak into sibling subtrees (e.g. the other branch of an OR).
     const Names & primary_keys) const
 {
     bool contains_pk = false;
@@ -2168,7 +2170,7 @@ ASTPtr InterpreterSelectQuery::analyze_where_ast(
 
         if (proj_pks.contains(col_name) && !contains_pk)
         {
-            ASTPtr rewrite_ast = create_proj_optimized_ast(ast, primary_keys);
+            ASTPtr rewrite_ast = create_proj_optimized_ast(ast, primary_key_predicates, primary_keys);
             auto and_func = makeASTFunction("and", std::move(rewrite_ast), ast->clone());
             return and_func;
         }
@@ -2207,19 +2209,31 @@ ASTPtr InterpreterSelectQuery::analyze_where_ast(
 
                 if (proj_pks_contains && !contains_pk)
                 {
-                    rewrite_ast = create_proj_optimized_ast(ast, primary_keys);
+                    rewrite_ast = create_proj_optimized_ast(ast, primary_key_predicates, primary_keys);
                     auto and_func = makeASTFunction("and", std::move(rewrite_ast), ast->clone());
                     return and_func;
                 }
             }
         }
-        else if (ast_function_node->name == "and" || ast_function_node->name == "or")
+        else if (ast_function_node->name == "and")
         {
+            findPrimaryKeyPredicates(ast, primary_key_predicates, primary_keys);
             auto current_func = makeASTFunction(ast_function_node->name);
             for (size_t i = 0; i < arg_size; i++)
             {
                 auto argument = ast_function_node->arguments->children[i];
-                auto new_ast = analyze_where_ast(argument, proj_pks, primary_keys);
+                auto new_ast = analyze_where_ast(argument, proj_pks, primary_key_predicates, primary_keys);
+                current_func->arguments->children.push_back(std::move(new_ast));
+            }
+            return current_func;
+        }
+        else if (ast_function_node->name == "or")
+        {
+            auto current_func = makeASTFunction(ast_function_node->name);
+            for (size_t i = 0; i < arg_size; i++)
+            {
+                auto argument = ast_function_node->arguments->children[i];
+                auto new_ast = analyze_where_ast(argument, proj_pks, primary_key_predicates, primary_keys);
                 current_func->arguments->children.push_back(std::move(new_ast));
             }
             return current_func;
@@ -2243,7 +2257,7 @@ ASTPtr InterpreterSelectQuery::analyze_where_ast(
  * The following code will convert this select query to the following
  * select * from test_a where src in (select src from test_a where dst='-42') and dst='-42';
  */
-ASTPtr InterpreterSelectQuery::create_proj_optimized_ast(const ASTPtr & ast, const Names & primary_keys) const
+ASTPtr InterpreterSelectQuery::create_proj_optimized_ast(const ASTPtr & ast, ASTs & primary_key_predicates, const Names & primary_keys) const
 {
     auto select_query = std::make_shared();
     select_query->setExpression(ASTSelectQuery::Expression::SELECT, std::make_shared());
@@ -2262,7 +2276,19 @@ ASTPtr InterpreterSelectQuery::create_proj_optimized_ast(const ASTPtr & ast, con
     auto tables_in_select = std::make_shared();
     tables_in_select->children.push_back(std::move(tables_elem));
     select_query->setExpression(ASTSelectQuery::Expression::TABLES, tables_in_select);
-    select_query->setExpression(ASTSelectQuery::Expression::WHERE, ast->clone());
+
+    if (!primary_key_predicates.empty())
+    {
+        auto new_where_predicates = makeASTFunction("and");
+        for (const auto & predicate : primary_key_predicates)
+            new_where_predicates->arguments->children.push_back(predicate->clone()); /// Clone: the same predicate may be pushed into several subqueries.
+        new_where_predicates->arguments->children.push_back(ast->clone());
+        select_query->setExpression(ASTSelectQuery::Expression::WHERE, std::move(new_where_predicates));
+    }
+    else
+    {
+        select_query->setExpression(ASTSelectQuery::Expression::WHERE, ast->clone());
+    }
 
     select_with_union_query->list_of_selects->children.push_back(select_query);
     select_with_union_query->children.push_back(select_with_union_query->list_of_selects);
@@ -2291,6 +2317,47 @@ ASTPtr InterpreterSelectQuery::create_proj_optimized_ast(const ASTPtr & ast, con
     return makeASTFunction("indexHint", std::move(in_function));
 }
 
+/// Appends to `primary_key_predicates` a clone of every two-argument comparison in `where_predicate`
+/// whose identifier (taken from the left argument, or from the right one if the left yields none)
+/// is listed in `primary_keys`. Only `where_predicate` itself and the arguments of (possibly nested)
+/// `and` functions are inspected, so each collected predicate is a conjunct of `where_predicate`.
+void InterpreterSelectQuery::findPrimaryKeyPredicates(const ASTPtr & where_predicate, ASTs & primary_key_predicates, const Names & primary_keys) const
+{
+    const auto * func = where_predicate->as<ASTFunction>();
+    if (!func || !func->arguments)
+        return;
+
+    static const std::unordered_set<String> supported_predicates_relations = {
+        "equals",
+        "notEquals",
+        "less",
+        "greater",
+        "lessOrEquals",
+        "greaterOrEquals",
+    };
+
+    const auto & arguments = func->arguments->children;
+    if (func->name == "and")
+    {
+        for (const auto & argument : arguments)
+            findPrimaryKeyPredicates(argument, primary_key_predicates, primary_keys);
+        return;
+    }
+
+    if (arguments.size() != 2 || !supported_predicates_relations.contains(func->name))
+        return;
+
+    /// getIdentifier() takes a non-const reference, hence the local copies of the pointers.
+    ASTPtr lhs_argument = arguments[0];
+    ASTPtr rhs_argument = arguments[1];
+    String col_name = getIdentifier(lhs_argument);
+    if (col_name.empty())
+        col_name = getIdentifier(rhs_argument);
+
+    if (std::find(primary_keys.begin(), primary_keys.end(), col_name) != primary_keys.end())
+        primary_key_predicates.push_back(where_predicate->clone());
+}
+
 /// Note that this is const and accepts the analysis ref to be able to use it to do analysis for parallel replicas
 /// without affecting the final analysis multiple times
 void InterpreterSelectQuery::applyFiltersToPrewhereInAnalysis(ExpressionAnalysisResult & analysis) const
@@ -3494,7 +3561,7 @@ String InterpreterSelectQuery::getIdentifier(ASTPtr & argument) const
 {
     if (const auto * id = argument->as())
         return id->name();
-    else if (argument->as())
+    else if (argument->as() || argument->children.size() == 0)
         return "";
     else
         return getIdentifier(argument->children.at(0));
diff --git a/src/Interpreters/InterpreterSelectQuery.h b/src/Interpreters/InterpreterSelectQuery.h
index 24374152a5d4..46b666463790 100644
--- a/src/Interpreters/InterpreterSelectQuery.h
+++ b/src/Interpreters/InterpreterSelectQuery.h
@@ -165,9 +165,10 @@ class InterpreterSelectQuery : public IInterpreterUnionOrSelectQuery
     ASTSelectQuery & getSelectQuery() { return query_ptr->as(); }
 
     ASTPtr pkOptimization(const ProjectionsDescription & projections, const ASTPtr & where_ast, const Names & primary_keys) const;
-    ASTPtr create_proj_optimized_ast(const ASTPtr & ast, const Names & primary_keys) const;
+    ASTPtr create_proj_optimized_ast(const ASTPtr & ast, ASTs & primary_key_predicates, const Names & primary_keys) const;
 
-    ASTPtr analyze_where_ast(const ASTPtr & ast, NameSet & proj_pks, const Names & primary_keys) const;
+    ASTPtr analyze_where_ast(const ASTPtr & ast, NameSet & proj_pks, ASTs primary_key_predicates, const Names & primary_keys) const;
+    void findPrimaryKeyPredicates(const ASTPtr & where_predicate, ASTs & primary_key_predicates, const Names & primary_keys) const;
     void addPrewhereAliasActions();
     void applyFiltersToPrewhereInAnalysis(ExpressionAnalysisResult & analysis) const;
     bool shouldMoveToPrewhere() const;
diff --git a/tests/queries/0_stateless/02923_projection_query_optimize.reference b/tests/queries/0_stateless/02923_projection_query_optimize.reference
index 6fff661102b5..1c0cca0a35df 100644
--- a/tests/queries/0_stateless/02923_projection_query_optimize.reference
+++ b/tests/queries/0_stateless/02923_projection_query_optimize.reference
@@ -57,12 +57,21 @@ WHERE (indexHint(src IN (
 )) AND (lower(id) = \'43\')) AND (id2 = \'44\')
 SELECT src2
 FROM test_c
+WHERE (src = \'44\') AND (indexHint((src, src2) IN (
+    SELECT
+        src,
+        src2
+    FROM projection_optimization_test.test_c
+    WHERE (src = \'44\') AND (src2 = \'44\') AND (dst = \'-42\')
+)) AND (dst = \'-42\')) AND (src2 = \'44\')
+SELECT src2
+FROM test_c
 WHERE ((src = \'44\') OR (indexHint((src, src2) IN (
     SELECT
         src,
         src2
     FROM projection_optimization_test.test_c
-    WHERE dst = \'-42\'
+    WHERE (src2 = \'44\') AND (dst = \'-42\')
 )) AND (dst = \'-42\'))) AND (src2 = \'44\')
 WITH (\'-41\', \'-42\', \'-43\') AS dst_list
 SELECT src
@@ -99,6 +108,21 @@ WHERE (indexHint((src, src2) IN (
     FROM projection_optimization_test.test_c
     WHERE c2 = \'41\'
 )) AND (c2 = \'41\')))
+SELECT src
+FROM test_c
+WHERE (src = \'40\') AND ((indexHint((src, src2) IN (
+    SELECT
+        src,
+        src2
+    FROM projection_optimization_test.test_c
+    WHERE (src = \'40\') AND (c1 = \'39\')
+)) AND (c1 = \'39\')) OR (indexHint((src, src2) IN (
+    SELECT
+        src,
+        src2
+    FROM projection_optimization_test.test_c
+    WHERE (src = \'40\') AND (c2 = \'41\')
+)) AND (c2 = \'41\')))
 40
 21
 40
diff --git a/tests/queries/0_stateless/02923_projection_query_optimize.sql b/tests/queries/0_stateless/02923_projection_query_optimize.sql
index 061982c5e114..da96f19af40d 100644
--- a/tests/queries/0_stateless/02923_projection_query_optimize.sql
+++ b/tests/queries/0_stateless/02923_projection_query_optimize.sql
@@ -23,13 +23,16 @@ EXPLAIN SYNTAX select c1 as id, c2 as id2 from test_b where lower(id) = '43' and
 
 CREATE TABLE test_c(src String, src2 String, dst String, c1 String, c2 String, c3 String, other_cols String, PROJECTION p1(SELECT src, dst ORDER BY dst), PROJECTION p2(SELECT src, c1, c2, c3 ORDER BY c1), PROJECTION p3(SELECT src, c1, c2, c3 ORDER BY c2), PROJECTION p4(SELECT src, src2 ORDER BY c3)) ENGINE = MergeTree ORDER BY (src, src2);
 insert into test_c select number, number, -number, number-1, number+1, number+2, 'other_col '||toString(number) from numbers(100);
+EXPLAIN SYNTAX select src2 from test_c where src = '44' and dst = '-42' and src2 = '44';
 EXPLAIN SYNTAX select src2 from test_c where (src = '44' or dst = '-42') and src2 = '44';
+
 
 EXPLAIN SYNTAX WITH ('-41', '-42', '-43') AS dst_list select src from test_c where dst in dst_list or src = '20';
 WITH ('-41', '-42', '-43') AS dst_list select src from test_c where dst in dst_list or src = '20';
 
 EXPLAIN SYNTAX select src from test_c where c3 = '42' and (c1 = '39' or c2 = '41');
-select src from test_c where c3 = '42' and (c1 = '39' or c2 = '41');
+EXPLAIN SYNTAX select src from test_c where src = '40' and (c1 = '39' or c2 = '41');
+select src from test_c where c3 = '42' and (c1 = '39' or c2 = '41');
 select src from test_c where (c1 = '20' or c2 = '41');
 select src from test_c where dst = '-21' and (c1 = '20' or c2 = '41');
 