diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java index d93bb710286..e127dc57c5b 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/skipping/DataSkippingUtils.java @@ -19,6 +19,7 @@ import static io.delta.kernel.internal.InternalScanFileUtils.ADD_FILE_ORDINAL; import static io.delta.kernel.internal.InternalScanFileUtils.ADD_FILE_STATS_ORDINAL; import static io.delta.kernel.internal.util.ExpressionUtils.*; +import static io.delta.kernel.internal.util.Preconditions.checkArgument; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.ColumnarBatch; @@ -29,6 +30,7 @@ import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; import java.util.*; +import java.util.function.BiFunction; public class DataSkippingUtils { @@ -254,6 +256,7 @@ private static Optional constructDataSkippingFilter( case "<=": case ">": case ">=": + case "IS NOT DISTINCT FROM": Expression left = getLeft(dataFilters); Expression right = getRight(dataFilters); @@ -262,9 +265,8 @@ private static Optional constructDataSkippingFilter( Literal rightLit = (Literal) right; if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol) && schemaHelper.isSkippingEligibleLiteral(rightLit)) { - return Optional.of( - constructComparatorDataSkippingFilters( - dataFilters.getName(), leftCol, rightLit, schemaHelper)); + return constructComparatorDataSkippingFilters( + dataFilters.getName(), leftCol, rightLit, schemaHelper); } } else if (right instanceof Column && left instanceof Literal) { return constructDataSkippingFilter(reverseComparatorFilter(dataFilters), schemaHelper); @@ -281,7 +283,7 @@ private static Optional constructDataSkippingFilter( } /** Construct the skipping predicate for a given comparator */ - private static DataSkippingPredicate constructComparatorDataSkippingFilters( + private static Optional constructComparatorDataSkippingFilters( String comparator, Column leftCol, Literal rightLit, StatsSchemaHelper schemaHelper) { switch (comparator.toUpperCase(Locale.ROOT)) { @@ -289,32 +291,39 @@ private static DataSkippingPredicate constructComparatorDataSkippingFilters( // Match any file whose min/max range contains the requested point. case "=": // For example a = 1 --> minValue.a <= 1 AND maxValue.a >= 1 - return new DataSkippingPredicate( - "AND", - constructBinaryDataSkippingPredicate( - "<=", schemaHelper.getMinColumn(leftCol), rightLit), - constructBinaryDataSkippingPredicate( - ">=", schemaHelper.getMaxColumn(leftCol), rightLit)); + return Optional.of( + new DataSkippingPredicate( + "AND", + constructBinaryDataSkippingPredicate( + "<=", schemaHelper.getMinColumn(leftCol), rightLit), + constructBinaryDataSkippingPredicate( + ">=", schemaHelper.getMaxColumn(leftCol), rightLit))); // Match any file whose min is less than the requested upper bound. case "<": - return constructBinaryDataSkippingPredicate( - "<", schemaHelper.getMinColumn(leftCol), rightLit); + return Optional.of( + constructBinaryDataSkippingPredicate( + "<", schemaHelper.getMinColumn(leftCol), rightLit)); // Match any file whose min is less than or equal to the requested upper bound case "<=": - return constructBinaryDataSkippingPredicate( - "<=", schemaHelper.getMinColumn(leftCol), rightLit); + return Optional.of( + constructBinaryDataSkippingPredicate( + "<=", schemaHelper.getMinColumn(leftCol), rightLit)); // Match any file whose max is larger than the requested lower bound. case ">": - return constructBinaryDataSkippingPredicate( - ">", schemaHelper.getMaxColumn(leftCol), rightLit); + return Optional.of( + constructBinaryDataSkippingPredicate( + ">", schemaHelper.getMaxColumn(leftCol), rightLit)); // Match any file whose max is larger than or equal to the requested lower bound. case ">=": - return constructBinaryDataSkippingPredicate( - ">=", schemaHelper.getMaxColumn(leftCol), rightLit); + return Optional.of( + constructBinaryDataSkippingPredicate( + ">=", schemaHelper.getMaxColumn(leftCol), rightLit)); + case "IS NOT DISTINCT FROM": + return constructDataSkippingFilter(rewriteEqualNullSafe(leftCol, rightLit), schemaHelper); default: throw new IllegalArgumentException( String.format("Unsupported comparator expression %s", comparator)); @@ -342,6 +351,7 @@ private static DataSkippingPredicate constructBinaryDataSkippingPredicate( put("<=", ">="); put(">", "<"); put(">=", "<="); + put("IS NOT DISTINCT FROM", "IS NOT DISTINCT FROM"); } }; @@ -402,29 +412,21 @@ private static Optional constructNotDataSkippingFilters( new Predicate("IS_NOT_NULL", getUnaryChild(childPredicate)), schemaHelper); case "=": - Expression left = getLeft(childPredicate); - Expression right = getRight(childPredicate); - if (left instanceof Column && right instanceof Literal) { - Column leftCol = (Column) left; - Literal rightLit = (Literal) right; - if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol) - && schemaHelper.isSkippingEligibleLiteral(rightLit)) { - // Match any file whose min/max range contains anything other than the - // rejected point. - // For example a != 1 --> minValue.a < 1 OR maxValue.a > 1 - return Optional.of( - new DataSkippingPredicate( - "OR", - constructBinaryDataSkippingPredicate( - "<", schemaHelper.getMinColumn(leftCol), rightLit), - constructBinaryDataSkippingPredicate( - ">", schemaHelper.getMaxColumn(leftCol), rightLit))); - } - } else if (right instanceof Column && left instanceof Literal) { - return constructDataSkippingFilter( - new Predicate("NOT", new Predicate("=", right, left)), schemaHelper); - } - break; + return constructDataSkippingFiltersForNotEqual( + childPredicate, + schemaHelper, + (leftColumn, rightLiteral) -> { + // Match any file whose min/max range contains anything other than the + // rejected point. + // For example a != 1 --> minValue.a < 1 OR maxValue.a > 1 + return Optional.of( + new DataSkippingPredicate( + "OR", + constructBinaryDataSkippingPredicate( + "<", schemaHelper.getMinColumn(leftColumn), rightLiteral), + constructBinaryDataSkippingPredicate( + ">", schemaHelper.getMaxColumn(leftColumn), rightLiteral))); + }); case "<": return constructDataSkippingFilter( new Predicate(">=", childPredicate.getChildren()), schemaHelper); @@ -437,6 +439,14 @@ private static Optional constructNotDataSkippingFilters( case ">=": return constructDataSkippingFilter( new Predicate("<", childPredicate.getChildren()), schemaHelper); + case "IS NOT DISTINCT FROM": + return constructDataSkippingFiltersForNotEqual( + childPredicate, + schemaHelper, + (leftColumn, rightLiteral) -> + constructDataSkippingFilter( + new Predicate("NOT", rewriteEqualNullSafe(leftColumn, rightLiteral)), + schemaHelper)); case "NOT": // Remove redundant pairs of NOT return constructDataSkippingFilter( @@ -510,4 +520,43 @@ private static String[] appendArray(String[] arr, String appendElem) { newNames[arr.length] = appendElem; return newNames; } + + /** + * Rewrite `EqualNullSafe(a, NotNullLiteral)` as `And(IsNotNull(a), EqualTo(a, NotNullLiteral))` + * and rewrite `EqualNullSafe(a, null)` as `IsNull(a)` + */ + private static Predicate rewriteEqualNullSafe(Column leftCol, Literal rightLit) { + if (rightLit.getValue() == null) { + return new Predicate("IS_NULL", leftCol); + } + return new Predicate( + "AND", new Predicate("IS_NOT_NULL", leftCol), new Predicate("=", leftCol, rightLit)); + } + + /** Helper method for building DataSkippingPredicate for NOT =/IS NOT DISTINCT FROM */ + private static Optional constructDataSkippingFiltersForNotEqual( + Predicate equalPredicate, + StatsSchemaHelper schemaHelper, + BiFunction> buildDataSkippingPredicateFunc) { + checkArgument( + "=".equals(equalPredicate.getName()) + || "IS NOT DISTINCT FROM".equals(equalPredicate.getName()), + "Expects predicate to be = or IS NOT DISTINCT FROM"); + Expression leftChild = getLeft(equalPredicate); + Expression rightChild = getRight(equalPredicate); + if (rightChild instanceof Column && leftChild instanceof Literal) { + return constructDataSkippingFilter( + new Predicate("NOT", new Predicate(equalPredicate.getName(), rightChild, leftChild)), + schemaHelper); + } + if (leftChild instanceof Column && rightChild instanceof Literal) { + Column leftCol = (Column) leftChild; + Literal rightLit = (Literal) rightChild; + if (schemaHelper.isSkippingEligibleMinMaxColumn(leftCol) + && schemaHelper.isSkippingEligibleLiteral(rightLit)) { + return buildDataSkippingPredicateFunc.apply(leftCol, rightLit); + } + } + return Optional.empty(); + } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala index 5c6ece0efa0..2d895472868 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala @@ -192,10 +192,7 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with nullSafeEquals(ofInt(1), col("a")), // 1 <=> a not(nullSafeEquals(col("a"), ofInt(2))), // NOT a <=> 2 // MOVE BELOW EXPRESSIONS TO MISSES ONCE SUPPORTED BY DATA SKIPPING - not(nullSafeEquals(col("a"), ofInt(1))), // NOT a <=> 1 - nullSafeEquals(col("a"), ofInt(2)), // a <=> 2 notEquals(col("a"), ofInt(1)), // a != 1 - nullSafeEquals(col("a"), ofInt(2)), // a <=> 2 notEquals(ofInt(1), col("a")) // 1 != a ), misses = Seq( @@ -210,7 +207,11 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with lessThanOrEqual(ofInt(2), col("a")), // 2 <= a greaterThanOrEqual(ofInt(0), col("a")), // 0 >= a not(equals(col("a"), ofInt(1))), // NOT a = 1 - not(equals(ofInt(1), col("a"))) // NOT 1 = a + not(equals(ofInt(1), col("a"))), // NOT 1 = a + not(nullSafeEquals(col("a"), ofInt(1))), // NOT a <=> 1 + not(nullSafeEquals(ofInt(1), col("a"))), // NOT 1 <=> a + nullSafeEquals(ofInt(2), col("a")), // 2 <=> a + nullSafeEquals(col("a"), ofInt(2)) // a <=> 2 ) ) @@ -762,15 +763,12 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with lessThan(col("a"), ofInt(1)), greaterThan(col("a"), ofInt(1)), not(equals(col("a"), ofInt(1))), - notEquals(col("a"), ofInt(1)), - nullSafeEquals(col("a"), ofInt(1)), - - // MOVE BELOW EXPRESSIONS TO MISSES ONCE SUPPORTED BY DATA SKIPPING - // This can be optimized to `IsNotNull(a)` (done by NullPropagation in Spark) - not(nullSafeEquals(col("a"), ofNull(INTEGER))) + notEquals(col("a"), ofInt(1)) ), misses = Seq( AlwaysFalse.ALWAYS_FALSE, + nullSafeEquals(col("a"), ofInt(1)), + not(nullSafeEquals(col("a"), ofNull(INTEGER))), isNotNull(col("a")) ) ) @@ -1054,10 +1052,9 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with expNumPartitions = 1, expNumFiles = 3) // 3 files with key = null - /* - NOT YET SUPPORTED EXPRESSIONS + checkResults( - predicate = nullSafeEquals(col("key"), ofNull(string)), + predicate = nullSafeEquals(col("key"), ofNull(STRING)), expNumPartitions = 1, expNumFiles = 3) // 3 files with key = null @@ -1070,7 +1067,6 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with predicate = nullSafeEquals(col("key"), ofString("b")), expNumPartitions = 1, expNumFiles = 1) // 1 files with key <=> 'b' - */ // Conditions on partitions keys and values checkResults( @@ -1086,7 +1082,12 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with checkResults( predicate = nullSafeEquals(col("value"), ofNull(STRING)), expNumPartitions = 3, - expNumFiles = 5) // should be 3 once <=> is supported + expNumFiles = 3) + + checkResults( + predicate = nullSafeEquals(ofNull(STRING), col("value")), + expNumPartitions = 3, + expNumFiles = 3) checkResults( predicate = equals(col("value"), ofString("a")), @@ -1095,8 +1096,13 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with checkResults( predicate = nullSafeEquals(col("value"), ofString("a")), - expNumPartitions = 3, // should be 2 once <=> is supported - expNumFiles = 5) // should be 2 once <=> is supported + expNumPartitions = 2, + expNumFiles = 2) + + checkResults( + predicate = nullSafeEquals(ofString("a"), col("value")), + expNumPartitions = 2, + expNumFiles = 2) checkResults( predicate = notEquals(col("value"), ofString("a")), @@ -1110,8 +1116,8 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with checkResults( predicate = nullSafeEquals(col("value"), ofString("b")), - expNumPartitions = 3, // should be 1 once <=> is supported - expNumFiles = 5) // should be 1 once <=> is supported + expNumPartitions = 1, + expNumFiles = 1) // Conditions on both, partition keys and values /* diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/ExpressionTestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/ExpressionTestUtils.scala index 8cc1b131c56..fa43347b961 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/ExpressionTestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/ExpressionTestUtils.scala @@ -59,6 +59,10 @@ trait ExpressionTestUtils { def str(value: String): Literal = Literal.ofString(value) + def nullSafeEquals(e1: Expression, e2: Expression): Predicate = { + new Predicate("IS NOT DISTINCT FROM", e1, e2) + } + def unsupported(colName: String): Predicate = predicate("UNSUPPORTED", col(colName)); /* ---------- NOT-YET SUPPORTED EXPRESSIONS ----------- */ @@ -70,8 +74,6 @@ trait ExpressionTestUtils { them to expect skipped files. If they are ever actually evaluated they will throw an exception. */ - def nullSafeEquals(e1: Expression, e2: Expression): Predicate = new Predicate("<=>", e1, e2) - def notEquals(e1: Expression, e2: Expression): Predicate = new Predicate("<>", e1, e2) def startsWith(e1: Expression, e2: Expression): Predicate = new Predicate("STARTS_WITH", e1, e2)