Skip to content

Commit

Permalink
Simplify case expressions with constant boolean results
Browse files Browse the repository at this point in the history
Simplifies expressions of the form:

   CASE
     WHEN a THEN true
     WHEN b THEN false
     WHEN c THEN true
     ...
     ELSE false
   END
  • Loading branch information
martint committed Sep 10, 2024
1 parent 56790c7 commit 64e5394
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,27 @@
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.ir.optimizer.IrOptimizerRule;
import io.trino.sql.planner.Symbol;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static io.trino.sql.ir.Booleans.FALSE;
import static io.trino.sql.ir.Booleans.TRUE;
import static io.trino.sql.planner.DeterminismEvaluator.isDeterministic;

/**
* Simplify CASE expressions with constant TRUE/FALSE results:
*
* <ul>
* <li>{@code Case([When(a, true), When(b, false), When(c, true)], false) -> $identical(Or(a, c), true)}
* <li>{@code Case([When(a, true), When(b, false), When(c, false)], true) -> $not($identical(Or(b, c), true)}
* </ul>
*/
public class SimplifyRedundantCase
implements IrOptimizerRule
{
Expand All @@ -47,20 +58,28 @@ public Optional<Expression> apply(Expression expression, Session session, Map<Sy
return Optional.empty();
}

// TODO: generalize to arbitrary number of clauses
if (caseTerm.whenClauses().size() != 1) {
Expression defaultValue = caseTerm.defaultValue();
if (!caseTerm.whenClauses().stream().map(WhenClause::getResult).allMatch(result -> result.equals(TRUE) || result.equals(FALSE)) ||
(!defaultValue.equals(TRUE) && !defaultValue.equals(FALSE)) ||
caseTerm.whenClauses().stream().map(WhenClause::getOperand).anyMatch(e -> !isDeterministic(e))) {
return Optional.empty();
}

WhenClause thenClause = caseTerm.whenClauses().getFirst();
if (thenClause.getResult().equals(TRUE) && caseTerm.defaultValue().equals(FALSE)) {
return Optional.of(new Comparison(Comparison.Operator.IDENTICAL, thenClause.getOperand(), TRUE));
}
if (defaultValue.equals(FALSE)) {
List<Expression> operands = caseTerm.whenClauses().stream()
.filter(clause -> clause.getResult().equals(TRUE))
.map(WhenClause::getOperand)
.toList();

if (thenClause.getResult().equals(FALSE) && caseTerm.defaultValue().equals(TRUE)) {
return Optional.of(IrExpressions.not(metadata, new Comparison(Comparison.Operator.IDENTICAL, thenClause.getOperand(), TRUE)));
return Optional.of(new Comparison(Comparison.Operator.IDENTICAL, IrUtils.or(operands), TRUE));
}
else {
List<Expression> operands = caseTerm.whenClauses().stream()
.filter(clause -> clause.getResult().equals(FALSE))
.map(WhenClause::getOperand)
.toList();

return Optional.empty();
return Optional.of(IrExpressions.not(metadata, new Comparison(Comparison.Operator.IDENTICAL, IrUtils.or(operands), TRUE)));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.metadata.ResolvedFunction;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.sql.ir.Call;
import io.trino.sql.ir.Case;
import io.trino.sql.ir.Comparison;
import io.trino.sql.ir.Constant;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.IrExpressions;
import io.trino.sql.ir.IrUtils;
import io.trino.sql.ir.Reference;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.ir.optimizer.rule.SimplifyRedundantCase;
Expand All @@ -28,17 +33,22 @@

import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.sql.ir.Booleans.FALSE;
import static io.trino.sql.ir.Booleans.TRUE;
import static io.trino.sql.ir.Comparison.Operator.EQUAL;
import static io.trino.sql.ir.Comparison.Operator.IDENTICAL;
import static io.trino.sql.ir.IrExpressions.not;
import static io.trino.sql.planner.TestingPlannerContext.PLANNER_CONTEXT;
import static io.trino.testing.TestingSession.testSession;
import static io.trino.transaction.InMemoryTransactionManager.createTestTransactionManager;
import static org.assertj.core.api.Assertions.assertThat;

public class TestSimplifyRedundantCase
{
private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(createTestTransactionManager(), PLANNER_CONTEXT);
private static final ResolvedFunction RANDOM = FUNCTIONS.resolveFunction("random", ImmutableList.of());

@Test
void test()
{
Expand All @@ -65,6 +75,37 @@ void test()
ImmutableList.of(new WhenClause(new Comparison(EQUAL, new Reference(BIGINT, "x"), new Constant(BIGINT, 1L)), FALSE)),
TRUE)))
.isEqualTo(Optional.of(not(PLANNER_CONTEXT.getMetadata(), new Comparison(IDENTICAL, new Comparison(EQUAL, new Reference(BIGINT, "x"), new Constant(BIGINT, 1L)), TRUE))));

assertThat(optimize(
new Case(
ImmutableList.of(
new WhenClause(new Reference(BOOLEAN, "x"), TRUE),
new WhenClause(new Reference(BOOLEAN, "y"), FALSE),
new WhenClause(new Reference(BOOLEAN, "z"), TRUE)),
FALSE)))
.isEqualTo(Optional.of(new Comparison(IDENTICAL, IrUtils.or(new Reference(BOOLEAN, "x"), new Reference(BOOLEAN, "z")), TRUE)));

assertThat(optimize(
new Case(
ImmutableList.of(
new WhenClause(new Reference(BOOLEAN, "x"), TRUE),
new WhenClause(new Reference(BOOLEAN, "y"), FALSE),
new WhenClause(new Reference(BOOLEAN, "z"), FALSE)),
TRUE)))
.isEqualTo(Optional.of(IrExpressions.not(PLANNER_CONTEXT.getMetadata(), new Comparison(IDENTICAL, IrUtils.or(new Reference(BOOLEAN, "y"), new Reference(BOOLEAN, "z")), TRUE))));
}

@Test
void testNonDeterministic()
{
assertThat(optimize(
new Case(
ImmutableList.of(
new WhenClause(new Comparison(EQUAL, new Call(RANDOM, ImmutableList.of()), new Constant(DOUBLE, 1.0)), TRUE),
new WhenClause(new Reference(BOOLEAN, "y"), FALSE),
new WhenClause(new Reference(BOOLEAN, "z"), TRUE)),
FALSE)))
.isEmpty();
}

private Optional<Expression> optimize(Expression expression)
Expand Down

0 comments on commit 64e5394

Please sign in to comment.