Skip to content

Commit

Permalink
[Spark] Support predicates for stats that are not at the top level (d…
Browse files Browse the repository at this point in the history
…elta-io#3117)

## Description

This refactoring adds support for nested statistics columns. So far, all
statistics are keys in the stats struct in AddFiles. This PR adds
support for statistics that are part of nested structs. This is a
prerequisite for file skipping on collated string columns ([Protocol
RFC](delta-io#3068)). Statistics for
collated string columns will be wrapped in a struct keyed by the
versioned collation that was used to generate them. For example:

```
"stats": { "statsWithCollation": { "icu.en_US.72": { "minValues": { ...} } } }
```

This PR replaces statType in StatsColumn with pathToStatType, which can
be used to represent a path. This way we can re-use all of the existing
data skipping code without changes.

## How was this patch tested?
It is not possible to test this change without altering
[statsSchema](https://github.com/delta-io/delta/blob/master/spark/src/main/scala/org/apache/spark/sql/delta/stats/StatisticsCollection.scala#L285).
I would still like to ship this PR separately because the change is big
enough in itself. There is existing test coverage for stats parsing and
file skipping, but none of them uses nested statistics yet.

## Does this PR introduce _any_ user-facing changes?
No
  • Loading branch information
olaky authored and longvu-db committed May 28, 2024
1 parent 3c38d12 commit 2e9412e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,32 +79,32 @@ object DataSkippingPredicateBuilder {
private [stats] class ColumnPredicateBuilder extends DataSkippingPredicateBuilder {
def equalTo(statsProvider: StatsProvider, colPath: Seq[String], value: Column)
: Option[DataSkippingPredicate] = {
statsProvider.getPredicateWithStatTypes(colPath, MIN, MAX) { (min, max) =>
statsProvider.getPredicateWithStatTypes(colPath, value.expr.dataType, MIN, MAX) { (min, max) =>
min <= value && value <= max
}
}

def notEqualTo(statsProvider: StatsProvider, colPath: Seq[String], value: Column)
: Option[DataSkippingPredicate] = {
statsProvider.getPredicateWithStatTypes(colPath, MIN, MAX) { (min, max) =>
statsProvider.getPredicateWithStatTypes(colPath, value.expr.dataType, MIN, MAX) { (min, max) =>
min < value || value < max
}
}

def lessThan(statsProvider: StatsProvider, colPath: Seq[String], value: Column)
: Option[DataSkippingPredicate] =
statsProvider.getPredicateWithStatType(colPath, MIN)(_ < value)
statsProvider.getPredicateWithStatType(colPath, value.expr.dataType, MIN)(_ < value)

def lessThanOrEqual(statsProvider: StatsProvider, colPath: Seq[String], value: Column)
: Option[DataSkippingPredicate] =
statsProvider.getPredicateWithStatType(colPath, MIN)(_ <= value)
statsProvider.getPredicateWithStatType(colPath, value.expr.dataType, MIN)(_ <= value)

def greaterThan(statsProvider: StatsProvider, colPath: Seq[String], value: Column)
: Option[DataSkippingPredicate] =
statsProvider.getPredicateWithStatType(colPath, MAX)(_ > value)
statsProvider.getPredicateWithStatType(colPath, value.expr.dataType, MAX)(_ > value)

def greaterThanOrEqual(statsProvider: StatsProvider, colPath: Seq[String], value: Column)
: Option[DataSkippingPredicate] =
statsProvider.getPredicateWithStatType(colPath, MAX)(_ >= value)
statsProvider.getPredicateWithStatType(colPath, value.expr.dataType, MAX)(_ >= value)
}

Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,20 @@ case class NumRecords(numPhysicalRecords: java.lang.Long, numLogicalRecords: jav
* Represents a stats column (MIN, MAX, etc) for a given (nested) user table column name. Used to
* keep track of which stats columns a data skipping query depends on.
*
* The `statType` is any value accepted by `getStatsColumnOpt()` (see object `DeltaStatistics`);
* The `pathToStatType` is path to a stats type accepted by `getStatsColumnOpt()`
* (see object `DeltaStatistics`);
* `pathToColumn` is the nested name of the user column whose stats are to be accessed.
* `columnDataType` is the data type of the column.
*/
private [stats] case class StatsColumn(
statType: String,
pathToColumn: Seq[String] = Nil)
private[stats] case class StatsColumn private(
pathToStatType: Seq[String],
pathToColumn: Seq[String])

object StatsColumn {
def apply(statType: String, pathToColumn: Seq[String], columnDataType: DataType): StatsColumn = {
StatsColumn(Seq(statType), pathToColumn)
}
}

/**
* A data skipping predicate, which includes the expression itself, plus the set of stats columns
Expand Down Expand Up @@ -236,8 +244,8 @@ trait DataSkippingReaderBase
protected def constructNotNullFilter(
statsProvider: StatsProvider,
pathToColumn: Seq[String]): Option[DataSkippingPredicate] = {
val nullCountCol = StatsColumn(NULL_COUNT, pathToColumn)
val numRecordsCol = StatsColumn(NUM_RECORDS)
val nullCountCol = StatsColumn(NULL_COUNT, pathToColumn, LongType)
val numRecordsCol = StatsColumn(NUM_RECORDS, pathToColumn = Nil, LongType)
statsProvider.getPredicateWithStatsColumns(nullCountCol, numRecordsCol) {
(nullCount, numRecords) => nullCount < numRecords
}
Expand Down Expand Up @@ -467,8 +475,8 @@ trait DataSkippingReaderBase
// Match any file whose null count is larger than zero.
// Note DVs might result in a redundant read of a file.
// However, they cannot lead to a correctness issue.
case IsNull(SkippingEligibleColumn(a, _)) =>
statsProvider.getPredicateWithStatType(a, NULL_COUNT) { nullCount =>
case IsNull(SkippingEligibleColumn(a, dt)) =>
statsProvider.getPredicateWithStatType(a, dt, NULL_COUNT) { nullCount =>
nullCount > Literal(0L)
}
case Not(IsNull(e)) =>
Expand Down Expand Up @@ -542,8 +550,8 @@ trait DataSkippingReaderBase

// Similar to an equality test, except comparing against a prefix of the min/max stats, and
// neither commutative nor invertible.
case StartsWith(SkippingEligibleColumn(a, _), v @ Literal(s: UTF8String, StringType)) =>
statsProvider.getPredicateWithStatTypes(a, MIN, MAX) { (min, max) =>
case StartsWith(SkippingEligibleColumn(a, _), v @ Literal(s: UTF8String, dt: StringType)) =>
statsProvider.getPredicateWithStatTypes(a, dt, MIN, MAX) { (min, max) =>
val sLen = s.numChars()
substring(min, 0, sLen) <= v && substring(max, 0, sLen) >= v
}
Expand Down Expand Up @@ -603,17 +611,35 @@ trait DataSkippingReaderBase
* Returns an expression to access the given statistics for a specific column, or None if that
* stats column does not exist.
*
* @param statType One of the fields declared by object `DeltaStatistics`
* @param pathToColumn The components of the nested column name to get stats for.
* @param pathToStatType Path components of one of the fields declared by the `DeltaStatistics`
* object. For statistics of collated strings, this path contains the
* versioned collation identifier. In all other cases the path only has one
* element. The path is in reverse order.
* @param pathToColumn The components of the nested column name to get stats for. The components
* are in reverse order.
*/
final protected def getStatsColumnOpt(statType: String, pathToColumn: Seq[String] = Nil)
: Option[Column] = {
// If the requested stats type doesn't even exist, just return None right away. This can
// legitimately happen if we have no stats at all, or if column stats are disabled (in which
// case only the NUM_RECORDS stat type is available).
if (!statsSchema.exists(_.name == statType)) {
return None
}
final protected def getStatsColumnOpt(
pathToStatType: Seq[String], pathToColumn: Seq[String]): Option[Column] = {

require(pathToStatType.nonEmpty, "No path to stats type provided.")

// First validate that pathToStatType is a valid path in the statsSchema. We start at the root
// of the stats schema and then follow the path. Note that the path is stored in reverse order.
// If one of the path components does not exist, the foldRight operation returns None.
val (initialColumn, initialFieldType) = pathToStatType
.foldRight(Option((getBaseStatsColumn, statsSchema.asInstanceOf[DataType]))) {
case (statTypePathComponent: String, Some((column: Column, struct: StructType))) =>
// Find the field matching the current path component name or return None otherwise.
struct.fields.collectFirst {
case StructField(name, dataType: DataType, _, _) if name == statTypePathComponent =>
(column.getField(statTypePathComponent), dataType)
}
case _ => None
}
// If the requested stats type doesn't even exist, just return None right away. This can
// legitimately happen if we have no stats at all, or if column stats are disabled (in which
// case only the NUM_RECORDS stat type is available).
.getOrElse { return None }

// Given a set of path segments in reverse order, e.g. column a.b.c is Seq("c", "b", "a"), we
// use a foldRight operation to build up the requested stats column, by successively applying
Expand All @@ -627,7 +653,7 @@ trait DataSkippingReaderBase
// step of the traversal emits the updated column, along with the stats schema and table schema
// elements corresponding to that column.
val initialState: Option[(Column, DataType, DataType)] =
Some((getBaseStatsColumn.getField(statType), statsSchema(statType).dataType, metadata.schema))
Some((initialColumn, initialFieldType, metadata.schema))
pathToColumn
.foldRight(initialState) {
// NOTE: Only match on StructType, because we cannot traverse through other DataTypes.
Expand All @@ -651,7 +677,7 @@ trait DataSkippingReaderBase
// Filter out non-leaf columns -- they lack stats so skipping predicates can't use them.
.filterNot(_._2.isInstanceOf[StructType])
.map {
case (statCol, TimestampType, _) if statType == MAX =>
case (statCol, TimestampType, _) if pathToStatType.head == MAX =>
// SC-22824: For timestamps, JSON serialization will truncate to milliseconds. This means
// that we must adjust 1 millisecond upwards for max stats, or we will incorrectly skip
// records that differ only in microsecond precision. (For example, a file containing only
Expand All @@ -661,7 +687,7 @@ trait DataSkippingReaderBase
// There is a longer term task SC-22825 to fix the serialization problem that caused this.
// But we need the adjustment in any case to correctly read stats written by old versions.
new Column(Cast(TimeAdd(statCol.expr, oneMillisecondLiteralExpr), TimestampType))
case (statCol, TimestampNTZType, _) if statType == MAX =>
case (statCol, TimestampNTZType, _) if pathToStatType.head == MAX =>
// We also apply the same adjustment of max stats that was applied to Timestamp
// for TimestampNTZ because these 2 types have the same precision in terms of time.
new Column(Cast(TimeAdd(statCol.expr, oneMillisecondLiteralExpr), TimestampNTZType))
Expand All @@ -670,22 +696,27 @@ trait DataSkippingReaderBase
}
}

/** Convenience overload for single element stat type paths. */
final protected def getStatsColumnOpt(
statType: String, pathToColumn: Seq[String] = Nil): Option[Column] =
getStatsColumnOpt(Seq(statType), pathToColumn)

/**
* Returns an expression to access the given statistics for a specific column, or a NULL
* literal expression if that column does not exist.
*/
final protected[delta] def getStatsColumnOrNullLiteral(
statType: String,
pathToColumn: Seq[String] = Nil) : Column =
getStatsColumnOpt(statType, pathToColumn).getOrElse(lit(null))
getStatsColumnOpt(Seq(statType), pathToColumn).getOrElse(lit(null))

/** Overload for convenience working with StatsColumn helpers */
final protected def getStatsColumnOpt(stat: StatsColumn): Option[Column] =
getStatsColumnOpt(stat.statType, stat.pathToColumn)
getStatsColumnOpt(stat.pathToStatType, stat.pathToColumn)

/** Overload for convenience working with StatsColumn helpers */
final protected[delta] def getStatsColumnOrNullLiteral(stat: StatsColumn): Column =
getStatsColumnOrNullLiteral(stat.statType, stat.pathToColumn)
getStatsColumnOpt(stat.pathToStatType, stat.pathToColumn).getOrElse(lit(null))

/**
* Returns an expression that can be used to check that the required statistics are present for a
Expand All @@ -708,8 +739,9 @@ trait DataSkippingReaderBase
// must return `TRUE`, and without these NULL checks it would instead return
// `NOT(NULL)` => `NULL`.
referencedStats.flatMap { stat => stat match {
case StatsColumn(MIN, _) | StatsColumn(MAX, _) =>
Seq(stat, StatsColumn(NULL_COUNT, stat.pathToColumn), StatsColumn(NUM_RECORDS))
case StatsColumn(MIN +: _, _) | StatsColumn(MAX +: _, _) =>
Seq(stat, StatsColumn(NULL_COUNT, stat.pathToColumn, LongType),
StatsColumn(NUM_RECORDS, pathToColumn = Nil, LongType))
case _ =>
Seq(stat)
}}.map{stat => stat match {
Expand All @@ -718,7 +750,7 @@ trait DataSkippingReaderBase
// NOTE: We don't care about NULL/missing NULL_COUNT and NUM_RECORDS here, because the
// separate NULL checks we emit for those columns will force the overall validation
// predicate conjunction to FALSE in that case -- AND(FALSE, <anything>) is FALSE.
case StatsColumn(MIN, _) | StatsColumn(MAX, _) =>
case StatsColumn(MIN +: _, _) | StatsColumn(MAX +: _, _) =>
getStatsColumnOrNullLiteral(stat).isNotNull ||
(getStatsColumnOrNullLiteral(NULL_COUNT, stat.pathToColumn) ===
getStatsColumnOrNullLiteral(NUM_RECORDS))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.spark.sql.delta.stats

import org.apache.spark.sql.Column
import org.apache.spark.sql.types.DataType

/**
* A helper class that provides the functionalities to create [[DataSkippingPredicate]] with
Expand Down Expand Up @@ -76,29 +77,32 @@ private [stats] class StatsProvider(getStat: StatsColumn => Option[Column]) {
* @return A [[DataSkippingPredicate]] with a data skipping expression, or None if the given
* stats column does not exist.
*/
def getPredicateWithStatType(pathToColumn: Seq[String], statType: String)
def getPredicateWithStatType(
pathToColumn: Seq[String], columnDataType: DataType, statType: String)
(f: Column => Column): Option[DataSkippingPredicate] = {
getPredicateWithStatsColumn(StatsColumn(statType, pathToColumn))(f)
getPredicateWithStatsColumn(StatsColumn(statType, pathToColumn, columnDataType))(f)
}

/** A variant of [[getPredicateWithStatType]] with two stat types. */
def getPredicateWithStatTypes(pathToColumn: Seq[String], statType1: String, statType2: String)
def getPredicateWithStatTypes(
pathToColumn: Seq[String], columnDataType: DataType, statType1: String, statType2: String)
(f: (Column, Column) => Column): Option[DataSkippingPredicate] = {
getPredicateWithStatsColumns(
StatsColumn(statType1, pathToColumn),
StatsColumn(statType2, pathToColumn))(f)
StatsColumn(statType1, pathToColumn, columnDataType),
StatsColumn(statType2, pathToColumn, columnDataType))(f)
}

/** A variant of [[getPredicateWithStatType]] with three stat types. */
def getPredicateWithStatTypes(
pathToColumn: Seq[String],
columnDataType: DataType,
statType1: String,
statType2: String,
statType3: String)
(f: (Column, Column, Column) => Column): Option[DataSkippingPredicate] = {
getPredicateWithStatsColumns(
StatsColumn(statType1, pathToColumn),
StatsColumn(statType2, pathToColumn),
StatsColumn(statType3, pathToColumn))(f)
StatsColumn(statType1, pathToColumn, columnDataType),
StatsColumn(statType2, pathToColumn, columnDataType),
StatsColumn(statType3, pathToColumn, columnDataType))(f)
}
}

0 comments on commit 2e9412e

Please sign in to comment.