Skip to content

Commit

Permalink
Add an optimization that removes redundant equality
Browse files Browse the repository at this point in the history
checks on boolean functions. This fixes a bug in
which the primary index is not used for queries like
SELECT * FROM <table> WHERE <pk> in (<n>) = 1
  • Loading branch information
josh-hildred committed Apr 9, 2024
1 parent f36ae13 commit 9d4f1d8
Show file tree
Hide file tree
Showing 4 changed files with 194 additions and 0 deletions.
76 changes: 76 additions & 0 deletions src/Analyzer/Passes/LogicalExpressionOptimizerPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ namespace ErrorCodes
extern const int LOGICAL_ERROR;
}

using namespace std::literals;
static constexpr std::array boolean_functions{
"equals"sv, "notEquals"sv, "less"sv, "greaterOrEquals"sv, "greater"sv, "lessOrEquals"sv, "in"sv, "notIn"sv,
"globalIn"sv, "globalNotIn"sv, "nullIn"sv, "notNullIn"sv, "globalNullIn"sv, "globalNullNotIn"sv, "isNull"sv, "isNotNull"sv,
"like"sv, "notLike"sv, "ilike"sv, "notILike"sv, "empty"sv, "notEmpty"sv, "not"sv, "and"sv,
"or"sv};

static bool isBooleanFunction(const String & func_name)
{
return std::any_of(
boolean_functions.begin(), boolean_functions.end(), [&](const auto boolean_func) { return func_name == boolean_func; });
}

/// Visitor that optimizes logical expressions _only_ in JOIN ON section
class JoinOnLogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithContext<JoinOnLogicalExpressionOptimizerVisitor>
{
Expand Down Expand Up @@ -253,6 +266,12 @@ class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithCont
tryOptimizeAndEqualsNotEqualsChain(node);
return;
}

if (function_node->getFunctionName() == "equals")
{
tryOptimizeOutRedundantEquals(node);
return;
}
}

private:
Expand Down Expand Up @@ -552,6 +571,63 @@ class LogicalExpressionOptimizerVisitor : public InDepthQueryTreeVisitorWithCont
function_node.getArguments().getNodes() = std::move(or_operands);
function_node.resolveAsFunction(or_function_resolver);
}

void tryOptimizeOutRedundantEquals(QueryTreeNodePtr & node)
{
auto & function_node = node->as<FunctionNode &>();
assert(function_node.getFunctionName() == "equals");

bool lhs_const;
bool maybe_invert;

const ConstantNode * constant;
const FunctionNode * child_function;

const auto function_arguments = function_node.getArguments().getNodes();
if (function_arguments.size() != 2)
return;

const auto & lhs = function_arguments[0];
const auto & rhs = function_arguments[1];

if ((constant = lhs->as<ConstantNode>()))
lhs_const = true;
else if ((constant = rhs->as<ConstantNode>()))
lhs_const = false;
else
return;

UInt64 val;
if (!constant->getValue().tryGet<UInt64>(val))
return;

if (val == 1)
maybe_invert = false;
else if (val == 0)
maybe_invert = true;
else
return;

if (lhs_const)
child_function = rhs->as<FunctionNode>();
else
child_function = lhs->as<FunctionNode>();

if (!child_function || !isBooleanFunction(child_function->getFunctionName()))
return;
if (maybe_invert)
{
auto not_resolver = FunctionFactory::instance().get("not", getContext());
const auto not_node = std::make_shared<FunctionNode>("not");
auto & arguments = not_node->getArguments().getNodes();
arguments.reserve(1);
arguments.push_back(lhs_const ? rhs : lhs);
not_node->resolveAsFunction(not_resolver->build(not_node->getArgumentColumns()));
node = not_node;
}
else
node = lhs_const ? rhs : lhs;
}
};

void LogicalExpressionOptimizerPass::run(QueryTreeNodePtr & query_tree_node, ContextPtr context)
Expand Down
12 changes: 12 additions & 0 deletions src/Analyzer/Passes/LogicalExpressionOptimizerPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ namespace DB
*
* SELECT * FROM t1 JOIN t2 ON a <=> b
* -------------------------------
*
* 7. Remove redundant equality checks on boolean functions.
* - these requndant checks cause the primary index to not be used when if the query involves any primary key columns
* -------------------------------
* SELECT * FROM t1 WHERE a IN (n) = 1
* SELECT * FROM t1 WHERE a IN (n) = 0
*
* will be transformed into
*
* SELECT * FROM t1 WHERE a IN (n)
* SELECT * FROM t1 WHERE NOT a IN (n)
* -------------------------------
*/

class LogicalExpressionOptimizerPass final : public IQueryTreePass
Expand Down
23 changes: 23 additions & 0 deletions tests/queries/0_stateless/03032_redundant_equals.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
100
100
100
100
100
100
0
0
0
1
100
101
100
101
100
101
100
1
1
1
1
1
1
83 changes: 83 additions & 0 deletions tests/queries/0_stateless/03032_redundant_equals.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
DROP TABLE IF EXISTS test_table;

CREATE TABLE test_table
(
k UInt64,
)
ENGINE = MergeTree
ORDER BY k;

INSERT INTO test_table SELECT number FROM numbers(10000000);

SELECT * FROM test_table WHERE k in (100) = 1;
SELECT * FROM test_table WHERE k = (100) = 1;
SELECT * FROM test_table WHERE k not in (100) = 0;
SELECT * FROM test_table WHERE k != (100) = 0;
SELECT * FROM test_table WHERE 1 = (k = 100);
SELECT * FROM test_table WHERE 0 = (k not in (100));
SELECT * FROM test_table WHERE k < 1 = 1;
SELECT * FROM test_table WHERE k >= 1 = 0;
SELECT * FROM test_table WHERE k > 1 = 0;
SELECT * FROM test_table WHERE ((k not in (101) = 0) OR (k in (100) = 1)) = 1;
SELECT * FROM test_table WHERE (NOT ((k not in (100) = 0) OR (k in (100) = 1))) = 0;
SELECT * FROM test_table WHERE (NOT ((k in (101) = 0) OR (k in (100) = 1))) = 1;
SELECT * FROM test_table WHERE ((k not in (101) = 0) OR (k in (100) = 1)) = 1;
SELECT * FROM test_table WHERE ((k not in (99) = 1) AND (k in (100) = 1)) = 1;

SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE k in (100) = 1
)
WHERE
explain LIKE '%Granules: 1/%';

SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE k >= 1 = 0
)
WHERE
explain LIKE '%Granules: 1/%';

SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE k not in (100) = 0
)
WHERE
explain LIKE '%Granules: 1/%';

SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE k > 1 = 0
)
WHERE
explain LIKE '%Granules: 1/%';

SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE (NOT ((k not in (100) = 0) OR (k in (100) = 1))) = 0
)
WHERE
explain LIKE '%Granules: 1/%';


SELECT count()
FROM
(
EXPLAIN PLAN indexes=1
SELECT * FROM test_table WHERE (NOT ((k in (101) = 0) OR (k in (100) = 1))) = 1
)
WHERE
explain LIKE '%Granules: 1/%';


DROP TABLE test_table;

0 comments on commit 9d4f1d8

Please sign in to comment.