[Kernel] Implement file skipping for nullSafeEquals (#4013)

#### Which Delta project/connector is this regarding?

- [ ] Spark
- [ ] Standalone
- [ ] Flink
- [x] Kernel
- [ ] Other (fill in here)

## Description

This PR adds data skipping support for `<=>` (null-safe equals) and addresses
#2538

The approach follows the one in Spark's data skipping reader (a sketch follows the reference below):

1. Rewrite `EqualNullSafe(a, NotNullLiteral)` as `And(IsNotNull(a), EqualTo(a, NotNullLiteral))`
2. Rewrite `EqualNullSafe(a, null)` as `IsNull(a)`

https://github.com/delta-io/delta/blob/master/spark/src/main/scala/org/apache/spark/sql/delta/stats/DataSkippingReader.scala#L508-L510
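For illustration, here is a minimal Java sketch of the rewrite, mirroring the `rewriteEqualNullSafe` helper added in this PR (it uses Kernel's `Column`, `Literal`, and `Predicate` expression classes):

```java
import io.delta.kernel.expressions.Column;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.expressions.Predicate;

public class NullSafeEqualsRewriteSketch {

  // EqualNullSafe(a, literal) --> IsNotNull(a) AND a = literal
  // EqualNullSafe(a, null)    --> IsNull(a)
  static Predicate rewriteEqualNullSafe(Column leftCol, Literal rightLit) {
    if (rightLit.getValue() == null) {
      return new Predicate("IS_NULL", leftCol);
    }
    return new Predicate(
        "AND", new Predicate("IS_NOT_NULL", leftCol), new Predicate("=", leftCol, rightLit));
  }
}
```

The rewritten predicate is then fed back through the existing data skipping logic for `IS_NULL`, `IS_NOT_NULL`, and `=`.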


## How was this patch tested?

Adjusted test cases in `ScanSuite.scala` according to review comments. For the `a <=> 1` all-null case, the expectations follow

https://github.com/delta-io/delta/blob/master/spark/src/test/scala/org/apache/spark/sql/delta/stats/DataSkippingDeltaTests.scala#L735
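For context, a minimal Java sketch of the predicates these tests exercise (the Scala `nullSafeEquals` and `not` helpers in `ExpressionTestUtils` build the same expressions; the integer column `a` is just an example):

```java
import io.delta.kernel.expressions.Column;
import io.delta.kernel.expressions.Literal;
import io.delta.kernel.expressions.Predicate;

public class NullSafeEqualsPredicates {
  // a <=> 1, using the operator name Kernel assigns to null-safe equality
  static final Predicate A_NULL_SAFE_EQUALS_ONE =
      new Predicate("IS NOT DISTINCT FROM", new Column("a"), Literal.ofInt(1));

  // NOT (a <=> 1), also covered by the adjusted ScanSuite cases
  static final Predicate NOT_A_NULL_SAFE_EQUALS_ONE =
      new Predicate("NOT", A_NULL_SAFE_EQUALS_ONE);
}
```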

## Does this PR introduce _any_ user-facing changes?

No
huan233usc authored Jan 14, 2025
1 parent 95e1826 commit 1afa48e
Showing 3 changed files with 119 additions and 62 deletions.
@@ -19,6 +19,7 @@
import static io.delta.kernel.internal.InternalScanFileUtils.ADD_FILE_ORDINAL;
import static io.delta.kernel.internal.InternalScanFileUtils.ADD_FILE_STATS_ORDINAL;
import static io.delta.kernel.internal.util.ExpressionUtils.*;
import static io.delta.kernel.internal.util.Preconditions.checkArgument;

import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.ColumnarBatch;
@@ -29,6 +30,7 @@
import io.delta.kernel.types.StructField;
import io.delta.kernel.types.StructType;
import java.util.*;
import java.util.function.BiFunction;

public class DataSkippingUtils {

@@ -254,6 +256,7 @@ private static Optional<DataSkippingPredicate> constructDataSkippingFilter(
case "<=":
case ">":
case ">=":
case "IS NOT DISTINCT FROM":
Expression left = getLeft(dataFilters);
Expression right = getRight(dataFilters);

@@ -262,9 +265,8 @@
Literal rightLit = (Literal) right;
if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol)
&& schemaHelper.isSkippingEligibleLiteral(rightLit)) {
return Optional.of(
constructComparatorDataSkippingFilters(
dataFilters.getName(), leftCol, rightLit, schemaHelper));
return constructComparatorDataSkippingFilters(
dataFilters.getName(), leftCol, rightLit, schemaHelper);
}
} else if (right instanceof Column && left instanceof Literal) {
return constructDataSkippingFilter(reverseComparatorFilter(dataFilters), schemaHelper);
@@ -281,40 +283,47 @@
}

/** Construct the skipping predicate for a given comparator */
private static DataSkippingPredicate constructComparatorDataSkippingFilters(
private static Optional<DataSkippingPredicate> constructComparatorDataSkippingFilters(
String comparator, Column leftCol, Literal rightLit, StatsSchemaHelper schemaHelper) {

switch (comparator.toUpperCase(Locale.ROOT)) {

// Match any file whose min/max range contains the requested point.
case "=":
// For example a = 1 --> minValue.a <= 1 AND maxValue.a >= 1
return new DataSkippingPredicate(
"AND",
constructBinaryDataSkippingPredicate(
"<=", schemaHelper.getMinColumn(leftCol), rightLit),
constructBinaryDataSkippingPredicate(
">=", schemaHelper.getMaxColumn(leftCol), rightLit));
return Optional.of(
new DataSkippingPredicate(
"AND",
constructBinaryDataSkippingPredicate(
"<=", schemaHelper.getMinColumn(leftCol), rightLit),
constructBinaryDataSkippingPredicate(
">=", schemaHelper.getMaxColumn(leftCol), rightLit)));

// Match any file whose min is less than the requested upper bound.
case "<":
return constructBinaryDataSkippingPredicate(
"<", schemaHelper.getMinColumn(leftCol), rightLit);
return Optional.of(
constructBinaryDataSkippingPredicate(
"<", schemaHelper.getMinColumn(leftCol), rightLit));

// Match any file whose min is less than or equal to the requested upper bound
case "<=":
return constructBinaryDataSkippingPredicate(
"<=", schemaHelper.getMinColumn(leftCol), rightLit);
return Optional.of(
constructBinaryDataSkippingPredicate(
"<=", schemaHelper.getMinColumn(leftCol), rightLit));

// Match any file whose max is larger than the requested lower bound.
case ">":
return constructBinaryDataSkippingPredicate(
">", schemaHelper.getMaxColumn(leftCol), rightLit);
return Optional.of(
constructBinaryDataSkippingPredicate(
">", schemaHelper.getMaxColumn(leftCol), rightLit));

// Match any file whose max is larger than or equal to the requested lower bound.
case ">=":
return constructBinaryDataSkippingPredicate(
">=", schemaHelper.getMaxColumn(leftCol), rightLit);
return Optional.of(
constructBinaryDataSkippingPredicate(
">=", schemaHelper.getMaxColumn(leftCol), rightLit));
case "IS NOT DISTINCT FROM":
return constructDataSkippingFilter(rewriteEqualNullSafe(leftCol, rightLit), schemaHelper);
default:
throw new IllegalArgumentException(
String.format("Unsupported comparator expression %s", comparator));
@@ -342,6 +351,7 @@ private static DataSkippingPredicate constructBinaryDataSkippingPredicate(
put("<=", ">=");
put(">", "<");
put(">=", "<=");
put("IS NOT DISTINCT FROM", "IS NOT DISTINCT FROM");
}
};

@@ -402,29 +412,21 @@ private static Optional<DataSkippingPredicate> constructNotDataSkippingFilters(
new Predicate("IS_NOT_NULL", getUnaryChild(childPredicate)), schemaHelper);

case "=":
Expression left = getLeft(childPredicate);
Expression right = getRight(childPredicate);
if (left instanceof Column && right instanceof Literal) {
Column leftCol = (Column) left;
Literal rightLit = (Literal) right;
if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol)
&& schemaHelper.isSkippingEligibleLiteral(rightLit)) {
// Match any file whose min/max range contains anything other than the
// rejected point.
// For example a != 1 --> minValue.a < 1 OR maxValue.a > 1
return Optional.of(
new DataSkippingPredicate(
"OR",
constructBinaryDataSkippingPredicate(
"<", schemaHelper.getMinColumn(leftCol), rightLit),
constructBinaryDataSkippingPredicate(
">", schemaHelper.getMaxColumn(leftCol), rightLit)));
}
} else if (right instanceof Column && left instanceof Literal) {
return constructDataSkippingFilter(
new Predicate("NOT", new Predicate("=", right, left)), schemaHelper);
}
break;
return constructDataSkippingFiltersForNotEqual(
childPredicate,
schemaHelper,
(leftColumn, rightLiteral) -> {
// Match any file whose min/max range contains anything other than the
// rejected point.
// For example a != 1 --> minValue.a < 1 OR maxValue.a > 1
return Optional.of(
new DataSkippingPredicate(
"OR",
constructBinaryDataSkippingPredicate(
"<", schemaHelper.getMinColumn(leftColumn), rightLiteral),
constructBinaryDataSkippingPredicate(
">", schemaHelper.getMaxColumn(leftColumn), rightLiteral)));
});
case "<":
return constructDataSkippingFilter(
new Predicate(">=", childPredicate.getChildren()), schemaHelper);
@@ -437,6 +439,14 @@ private static Optional<DataSkippingPredicate> constructNotDataSkippingFilters(
case ">=":
return constructDataSkippingFilter(
new Predicate("<", childPredicate.getChildren()), schemaHelper);
case "IS NOT DISTINCT FROM":
return constructDataSkippingFiltersForNotEqual(
childPredicate,
schemaHelper,
(leftColumn, rightLiteral) ->
constructDataSkippingFilter(
new Predicate("NOT", rewriteEqualNullSafe(leftColumn, rightLiteral)),
schemaHelper));
case "NOT":
// Remove redundant pairs of NOT
return constructDataSkippingFilter(
@@ -510,4 +520,43 @@ private static String[] appendArray(String[] arr, String appendElem) {
newNames[arr.length] = appendElem;
return newNames;
}

/**
* Rewrite `EqualNullSafe(a, NotNullLiteral)` as `And(IsNotNull(a), EqualTo(a, NotNullLiteral))`
* and rewrite `EqualNullSafe(a, null)` as `IsNull(a)`
*/
private static Predicate rewriteEqualNullSafe(Column leftCol, Literal rightLit) {
if (rightLit.getValue() == null) {
return new Predicate("IS_NULL", leftCol);
}
return new Predicate(
"AND", new Predicate("IS_NOT_NULL", leftCol), new Predicate("=", leftCol, rightLit));
}

/** Helper method for building DataSkippingPredicate for NOT =/IS NOT DISTINCT FROM */
private static Optional<DataSkippingPredicate> constructDataSkippingFiltersForNotEqual(
Predicate equalPredicate,
StatsSchemaHelper schemaHelper,
BiFunction<Column, Literal, Optional<DataSkippingPredicate>> buildDataSkippingPredicateFunc) {
checkArgument(
"=".equals(equalPredicate.getName())
|| "IS NOT DISTINCT FROM".equals(equalPredicate.getName()),
"Expects predicate to be = or IS NOT DISTINCT FROM");
Expression leftChild = getLeft(equalPredicate);
Expression rightChild = getRight(equalPredicate);
if (rightChild instanceof Column && leftChild instanceof Literal) {
return constructDataSkippingFilter(
new Predicate("NOT", new Predicate(equalPredicate.getName(), rightChild, leftChild)),
schemaHelper);
}
if (leftChild instanceof Column && rightChild instanceof Literal) {
Column leftCol = (Column) leftChild;
Literal rightLit = (Literal) rightChild;
if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol)
&& schemaHelper.isSkippingEligibleLiteral(rightLit)) {
return buildDataSkippingPredicateFunc.apply(leftCol, rightLit);
}
}
return Optional.empty();
}
}
@@ -192,10 +192,7 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with
nullSafeEquals(ofInt(1), col("a")), // 1 <=> a
not(nullSafeEquals(col("a"), ofInt(2))), // NOT a <=> 2
// MOVE BELOW EXPRESSIONS TO MISSES ONCE SUPPORTED BY DATA SKIPPING
not(nullSafeEquals(col("a"), ofInt(1))), // NOT a <=> 1
nullSafeEquals(col("a"), ofInt(2)), // a <=> 2
notEquals(col("a"), ofInt(1)), // a != 1
nullSafeEquals(col("a"), ofInt(2)), // a <=> 2
notEquals(ofInt(1), col("a")) // 1 != a
),
misses = Seq(
@@ -210,7 +207,11 @@
lessThanOrEqual(ofInt(2), col("a")), // 2 <= a
greaterThanOrEqual(ofInt(0), col("a")), // 0 >= a
not(equals(col("a"), ofInt(1))), // NOT a = 1
not(equals(ofInt(1), col("a"))) // NOT 1 = a
not(equals(ofInt(1), col("a"))), // NOT 1 = a
not(nullSafeEquals(col("a"), ofInt(1))), // NOT a <=> 1
not(nullSafeEquals(ofInt(1), col("a"))), // NOT 1 <=> a
nullSafeEquals(ofInt(2), col("a")), // 2 <=> a
nullSafeEquals(col("a"), ofInt(2)) // a <=> 2
)
)

@@ -762,15 +763,12 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with
lessThan(col("a"), ofInt(1)),
greaterThan(col("a"), ofInt(1)),
not(equals(col("a"), ofInt(1))),
notEquals(col("a"), ofInt(1)),
nullSafeEquals(col("a"), ofInt(1)),

// MOVE BELOW EXPRESSIONS TO MISSES ONCE SUPPORTED BY DATA SKIPPING
// This can be optimized to `IsNotNull(a)` (done by NullPropagation in Spark)
not(nullSafeEquals(col("a"), ofNull(INTEGER)))
notEquals(col("a"), ofInt(1))
),
misses = Seq(
AlwaysFalse.ALWAYS_FALSE,
nullSafeEquals(col("a"), ofInt(1)),
not(nullSafeEquals(col("a"), ofNull(INTEGER))),
isNotNull(col("a"))
)
)
@@ -1054,10 +1052,9 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with
expNumPartitions = 1,
expNumFiles = 3) // 3 files with key = null

/*
NOT YET SUPPORTED EXPRESSIONS

checkResults(
predicate = nullSafeEquals(col("key"), ofNull(string)),
predicate = nullSafeEquals(col("key"), ofNull(STRING)),
expNumPartitions = 1,
expNumFiles = 3) // 3 files with key = null

@@ -1070,7 +1067,6 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with
predicate = nullSafeEquals(col("key"), ofString("b")),
expNumPartitions = 1,
expNumFiles = 1) // 1 files with key <=> 'b'
*/

// Conditions on partitions keys and values
checkResults(
@@ -1086,7 +1082,12 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with
checkResults(
predicate = nullSafeEquals(col("value"), ofNull(STRING)),
expNumPartitions = 3,
expNumFiles = 5) // should be 3 once <=> is supported
expNumFiles = 3)

checkResults(
predicate = nullSafeEquals(ofNull(STRING), col("value")),
expNumPartitions = 3,
expNumFiles = 3)

checkResults(
predicate = equals(col("value"), ofString("a")),
@@ -1095,8 +1096,13 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with

checkResults(
predicate = nullSafeEquals(col("value"), ofString("a")),
expNumPartitions = 3, // should be 2 once <=> is supported
expNumFiles = 5) // should be 2 once <=> is supported
expNumPartitions = 2,
expNumFiles = 2)

checkResults(
predicate = nullSafeEquals(ofString("a"), col("value")),
expNumPartitions = 2,
expNumFiles = 2)

checkResults(
predicate = notEquals(col("value"), ofString("a")),
@@ -1110,8 +1116,8 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with

checkResults(
predicate = nullSafeEquals(col("value"), ofString("b")),
expNumPartitions = 3, // should be 1 once <=> is supported
expNumFiles = 5) // should be 1 once <=> is supported
expNumPartitions = 1,
expNumFiles = 1)

// Conditions on both, partition keys and values
/*
@@ -59,6 +59,10 @@ trait ExpressionTestUtils {

def str(value: String): Literal = Literal.ofString(value)

def nullSafeEquals(e1: Expression, e2: Expression): Predicate = {
new Predicate("IS NOT DISTINCT FROM", e1, e2)
}

def unsupported(colName: String): Predicate = predicate("UNSUPPORTED", col(colName));

/* ---------- NOT-YET SUPPORTED EXPRESSIONS ----------- */
Expand All @@ -70,8 +74,6 @@ trait ExpressionTestUtils {
them to expect skipped files. If they are ever actually evaluated they will throw an exception.
*/

def nullSafeEquals(e1: Expression, e2: Expression): Predicate = new Predicate("<=>", e1, e2)

def notEquals(e1: Expression, e2: Expression): Predicate = new Predicate("<>", e1, e2)

def startsWith(e1: Expression, e2: Expression): Predicate = new Predicate("STARTS_WITH", e1, e2)
