Skip to content

Commit

Permalink
pushing down primary key predicates to proj optimization indexhint() …
Browse files Browse the repository at this point in the history
…condition
  • Loading branch information
SuzyWangIBMer committed Apr 11, 2024
1 parent 6cc0d04 commit 120331f
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 12 deletions.
83 changes: 75 additions & 8 deletions src/Interpreters/InterpreterSelectQuery.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2142,12 +2142,14 @@ ASTPtr InterpreterSelectQuery::pkOptimization(
}
}

return analyze_where_ast(where_ast, proj_pks, primary_keys);
ASTs primary_key_predicates;
return analyze_where_ast(where_ast, proj_pks, primary_key_predicates, primary_keys);
}

ASTPtr InterpreterSelectQuery::analyze_where_ast(
const ASTPtr & ast,
NameSet & proj_pks,
ASTs & primary_key_predicates,
const Names & primary_keys) const
{
bool contains_pk = false;
Expand All @@ -2168,7 +2170,7 @@ ASTPtr InterpreterSelectQuery::analyze_where_ast(

if (proj_pks.contains(col_name) && !contains_pk)
{
ASTPtr rewrite_ast = create_proj_optimized_ast(ast, primary_keys);
ASTPtr rewrite_ast = create_proj_optimized_ast(ast, primary_key_predicates, primary_keys);
auto and_func = makeASTFunction("and", std::move(rewrite_ast), ast->clone());
return and_func;
}
Expand Down Expand Up @@ -2207,19 +2209,31 @@ ASTPtr InterpreterSelectQuery::analyze_where_ast(

if (proj_pks_contains && !contains_pk)
{
rewrite_ast = create_proj_optimized_ast(ast, primary_keys);
rewrite_ast = create_proj_optimized_ast(ast, primary_key_predicates, primary_keys);
auto and_func = makeASTFunction("and", std::move(rewrite_ast), ast->clone());
return and_func;
}
}
}
else if (ast_function_node->name == "and" || ast_function_node->name == "or")
else if (ast_function_node->name == "and")
{
findPrimaryKeyPredicates(ast, primary_key_predicates, primary_keys);
auto current_func = makeASTFunction(ast_function_node->name);
for (size_t i = 0; i < arg_size; i++)
{
auto argument = ast_function_node->arguments->children[i];
auto new_ast = analyze_where_ast(argument, proj_pks, primary_keys);
auto new_ast = analyze_where_ast(argument, proj_pks, primary_key_predicates, primary_keys);
current_func->arguments->children.push_back(std::move(new_ast));
}
return current_func;
}
else if (ast_function_node->name == "or")
{
auto current_func = makeASTFunction(ast_function_node->name);
for (size_t i = 0; i < arg_size; i++)
{
auto argument = ast_function_node->arguments->children[i];
auto new_ast = analyze_where_ast(argument, proj_pks, primary_key_predicates, primary_keys);
current_func->arguments->children.push_back(std::move(new_ast));
}
return current_func;
Expand All @@ -2243,7 +2257,7 @@ ASTPtr InterpreterSelectQuery::analyze_where_ast(
* The following code will convert this select query to the following
* select * from test_a where src in (select src from test_a where dst='-42') and dst='-42';
*/
ASTPtr InterpreterSelectQuery::create_proj_optimized_ast(const ASTPtr & ast, const Names & primary_keys) const
ASTPtr InterpreterSelectQuery::create_proj_optimized_ast(const ASTPtr & ast, ASTs & primary_key_predicates, const Names & primary_keys) const
{
auto select_query = std::make_shared<ASTSelectQuery>();
select_query->setExpression(ASTSelectQuery::Expression::SELECT, std::make_shared<ASTExpressionList>());
Expand All @@ -2262,7 +2276,19 @@ ASTPtr InterpreterSelectQuery::create_proj_optimized_ast(const ASTPtr & ast, con
auto tables_in_select = std::make_shared<ASTTablesInSelectQuery>();
tables_in_select->children.push_back(std::move(tables_elem));
select_query->setExpression(ASTSelectQuery::Expression::TABLES, tables_in_select);
select_query->setExpression(ASTSelectQuery::Expression::WHERE, ast->clone());

if (primary_key_predicates.size() >=1)
{
auto new_where_predicates = makeASTFunction("and");
for (auto predicates : primary_key_predicates)
new_where_predicates->arguments->children.push_back(predicates);
new_where_predicates->arguments->children.push_back(ast->clone());
select_query->setExpression(ASTSelectQuery::Expression::WHERE, std::move(new_where_predicates));
}
else
{
select_query->setExpression(ASTSelectQuery::Expression::WHERE, ast->clone());
}

select_with_union_query->list_of_selects->children.push_back(select_query);
select_with_union_query->children.push_back(select_with_union_query->list_of_selects);
Expand Down Expand Up @@ -2291,6 +2317,47 @@ ASTPtr InterpreterSelectQuery::create_proj_optimized_ast(const ASTPtr & ast, con
return makeASTFunction("indexHint", std::move(in_function));
}

void InterpreterSelectQuery::findPrimaryKeyPredicates(const ASTPtr & where_predicate, ASTs & primary_key_predicates, const Names & primary_keys) const
{
auto func = where_predicate->as<ASTFunction>();
if (!func)
return;

const static std::unordered_set<String> supported_predicates_relations = {
"equals",
"notEquals",
"less",
"greater",
"lessOrEquals",
"greaterOrEquals",
};

auto arg_size = func->arguments ? func->arguments->children.size() : 0;
if (supported_predicates_relations.contains(func->name) && arg_size == 2)
{
auto lhs_argument = func->arguments->children.at(0);
auto rhs_argument = func->arguments->children.at(1);
String lhs = getIdentifier(lhs_argument);
String rhs = getIdentifier(rhs_argument);
auto col_name = (!lhs.empty()) ? lhs:rhs;
bool contains_pk = false;
if (std::find(primary_keys.begin(), primary_keys.end(), col_name) != primary_keys.end())
contains_pk = true;
if (contains_pk)
{
primary_key_predicates.push_back(where_predicate->clone());
}

}
else if (func->name == "and")
{
for (size_t i = 0; i < arg_size; i++)
{
findPrimaryKeyPredicates(func->arguments->children.at(i), primary_key_predicates, primary_keys);
}
}
}

/// Note that this is const and accepts the analysis ref to be able to use it to do analysis for parallel replicas
/// without affecting the final analysis multiple times
void InterpreterSelectQuery::applyFiltersToPrewhereInAnalysis(ExpressionAnalysisResult & analysis) const
Expand Down Expand Up @@ -3494,7 +3561,7 @@ String InterpreterSelectQuery::getIdentifier(ASTPtr & argument) const
{
if (const auto * id = argument->as<ASTIdentifier>())
return id->name();
else if (argument->as<ASTLiteral>())
else if (argument->as<ASTLiteral>() || argument->children.size() == 0)
return "";
else
return getIdentifier(argument->children.at(0));
Expand Down
5 changes: 3 additions & 2 deletions src/Interpreters/InterpreterSelectQuery.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,10 @@ class InterpreterSelectQuery : public IInterpreterUnionOrSelectQuery
ASTSelectQuery & getSelectQuery() { return query_ptr->as<ASTSelectQuery &>(); }

ASTPtr pkOptimization(const ProjectionsDescription & projections, const ASTPtr & where_ast, const Names & primary_keys) const;
ASTPtr create_proj_optimized_ast(const ASTPtr & ast, const Names & primary_keys) const;
ASTPtr create_proj_optimized_ast(const ASTPtr & ast, ASTs & primary_key_predicates, const Names & primary_keys) const;

ASTPtr analyze_where_ast(const ASTPtr & ast, NameSet & proj_pks, const Names & primary_keys) const;
ASTPtr analyze_where_ast(const ASTPtr & ast, NameSet & proj_pks, ASTs & primary_key_predicates, const Names & primary_keys) const;
void findPrimaryKeyPredicates(const ASTPtr & where_predicate, ASTs & primary_key_predicates, const Names & primary_keys) const;
void addPrewhereAliasActions();
void applyFiltersToPrewhereInAnalysis(ExpressionAnalysisResult & analysis) const;
bool shouldMoveToPrewhere() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,21 @@ WHERE (indexHint(src IN (
)) AND (lower(id) = \'43\')) AND (id2 = \'44\')
SELECT src2
FROM test_c
WHERE (src = \'44\') AND (indexHint((src, src2) IN (
SELECT
src,
src2
FROM projection_optimization_test.test_c
WHERE (src = \'44\') AND (src2 = \'44\') AND (dst = \'-42\')
)) AND (dst = \'-42\')) AND (src2 = \'44\')
SELECT src2
FROM test_c
WHERE ((src = \'44\') OR (indexHint((src, src2) IN (
SELECT
src,
src2
FROM projection_optimization_test.test_c
WHERE dst = \'-42\'
WHERE (src2 = \'44\') AND (dst = \'-42\')
)) AND (dst = \'-42\'))) AND (src2 = \'44\')
WITH (\'-41\', \'-42\', \'-43\') AS dst_list
SELECT src
Expand Down Expand Up @@ -99,6 +108,21 @@ WHERE (indexHint((src, src2) IN (
FROM projection_optimization_test.test_c
WHERE c2 = \'41\'
)) AND (c2 = \'41\')))
SELECT src
FROM test_c
WHERE (src = \'40\') AND ((indexHint((src, src2) IN (
SELECT
src,
src2
FROM projection_optimization_test.test_c
WHERE (src = \'40\') AND (c1 = \'39\')
)) AND (c1 = \'39\')) OR (indexHint((src, src2) IN (
SELECT
src,
src2
FROM projection_optimization_test.test_c
WHERE (src = \'40\') AND (c2 = \'41\')
)) AND (c2 = \'41\')))
40
21
40
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ EXPLAIN SYNTAX select c1 as id, c2 as id2 from test_b where lower(id) = '43' and

CREATE TABLE test_c(src String, src2 String, dst String, c1 String, c2 String, c3 String, other_cols String, PROJECTION p1(SELECT src, dst ORDER BY dst), PROJECTION p2(SELECT src, c1, c2, c3 ORDER BY c1), PROJECTION p3(SELECT src, c1, c2, c3 ORDER BY c2), PROJECTION p4(SELECT src, src2 ORDER BY c3)) ENGINE = MergeTree ORDER BY (src, src2);
insert into test_c select number, number, -number, number-1, number+1, number+2, 'other_col '||toString(number) from numbers(100);
EXPLAIN SYNTAX select src2 from test_c where src = '44' and dst = '-42' and src2 = '44';
EXPLAIN SYNTAX select src2 from test_c where (src = '44' or dst = '-42') and src2 = '44';

EXPLAIN SYNTAX WITH ('-41', '-42', '-43') AS dst_list select src from test_c where dst in dst_list or src = '20';
WITH ('-41', '-42', '-43') AS dst_list select src from test_c where dst in dst_list or src = '20';

EXPLAIN SYNTAX select src from test_c where c3 = '42' and (c1 = '39' or c2 = '41');
select src from test_c where c3 = '42' and (c1 = '39' or c2 = '41');
EXPLAIN SYNTAX select src from test_c where src = '40' and (c1 = '39' or c2 = '41');

select src from test_c where c3 = '42' and (c1 = '39' or c2 = '41');
select src from test_c where (c1 = '20' or c2 = '41');
select src from test_c where dst = '-21' and (c1 = '20' or c2 = '41');

Expand Down

0 comments on commit 120331f

Please sign in to comment.