Skip to content

Commit

Permalink
[Spark] Add Read support for RowCommitVersion and Row Tracking Preser…
Browse files Browse the repository at this point in the history
…vation in Read/Write (#2878)

## Description
1. Adding the `row_commit_version` field to the _metadata column for
Delta tables, allowing us to read the `row_commit_version` from the file
metadata after it is stored.
2. Adding Row Tracking preservation in Read/Write.
<!--
- Describe what this PR changes.
- Describe why we need the change.
 
If this PR resolves an issue be sure to include "Resolves #XXX" to
correctly link and close the issue upon merge.
-->

## How was this patch tested?
Added UTs.
<!--
If tests were added, say they were added here. Please make sure to test
the changes thoroughly including negative and positive cases if
possible.
If the changes were tested in any way other than unit tests, please
clarify how you tested step by step (ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future).
If the changes were not tested, please explain why.
-->

## Does this PR introduce _any_ user-facing changes?
No.
<!--
If yes, please clarify the previous behavior and the change this PR
proposes - provide the console output, description and/or an example to
show the behavior difference if possible.
If possible, please also clarify if this is a user-facing change
compared to the released Delta Lake versions or within the unreleased
branches such as master.
If no, write 'No'.
-->
  • Loading branch information
longvu-db authored Apr 17, 2024
1 parent 0027d70 commit 1b210c2
Show file tree
Hide file tree
Showing 12 changed files with 695 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ object ColumnWithDefaultExprUtils extends DeltaLogging {
.map(new Column(_))
selectExprs = selectExprs ++ rowIdExprs

val rowCommitVersionExprs = data.queryExecution.analyzed.output
.filter(RowCommitVersion.MetadataAttribute.isRowCommitVersionColumn)
.map(new Column(_))
selectExprs = selectExprs ++ rowCommitVersionExprs

val newData = queryExecution match {
case incrementalExecution: IncrementalExecution =>
selectFromStreamingDataFrame(incrementalExecution, data, selectExprs: _*)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ trait DeltaColumnMappingBase extends DeltaLogging {

def isInternalField(field: StructField): Boolean =
DELTA_INTERNAL_COLUMNS.contains(field.name.toLowerCase(Locale.ROOT)) ||
RowIdMetadataStructField.isRowIdColumn(field)
RowIdMetadataStructField.isRowIdColumn(field) ||
RowCommitVersion.MetadataStructField.isRowCommitVersionColumn(field)

def satisfiesColumnMappingProtocol(protocol: Protocol): Boolean =
protocol.isFeatureSupported(ColumnMappingTableFeature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ object DeltaTableUtils extends PredicateHelper
}
}

/** Finds and returns the file source metadata column from a dataframe */
def getFileMetadataColumn(df: DataFrame): Column =
df.metadataColumn(FileFormat.METADATA_NAME)

/**
* Update FileFormat for a plan and return the updated plan
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ import org.apache.spark.sql.types.StructType

/**
* This rule adds a Project on top of Delta tables that support the Row tracking table feature to
* provide a default generated Row ID for rows that don't have them materialized in the data file.
* provide a default generated Row ID and row commit version for rows that don't have them
* materialized in the data file.
*/
object GenerateRowIDs extends Rule[LogicalPlan] {

Expand All @@ -49,9 +50,9 @@ object GenerateRowIDs extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithNewOutput {
case DeltaScanWithRowTrackingEnabled(scan) =>
// While Row IDs are non-nullable, we'll use the Row ID attributes to read
// the materialized values from now on, which can be null. We make
// the materialized Row ID attributes nullable in the scan here.
// While Row IDs and commit versions are non-nullable, we'll use the Row ID & commit
// version attributes to read the materialized values from now on, which can be null. We make
// the materialized Row ID & commit version attributes nullable in the scan here.

// Update nullability in the scan `metadataOutput` by updating the delta file format.
val baseRelation = scan.relation.asInstanceOf[HadoopFsRelation]
Expand Down Expand Up @@ -107,6 +108,17 @@ object GenerateRowIDs extends Rule[LogicalPlan] {
getField(metadata, ParquetFileFormat.ROW_INDEX))))
}

/**
* Expression that reads the Row commit versions from the materialized Row commit version column
* if the value is present and returns the default Row commit version from the file if not:
* coalesce(_metadata.row_commit_Version, _metadata.default_row_commit_version).
*/
private def rowCommitVersionExpr(metadata: AttributeReference): Expression = {
Coalesce(Seq(
getField(metadata, RowCommitVersion.METADATA_STRUCT_FIELD_NAME),
getField(metadata, DefaultRowCommitVersion.METADATA_STRUCT_FIELD_NAME)))
}

/**
* Extract a field from the metadata column.
*/
Expand All @@ -119,14 +131,16 @@ object GenerateRowIDs extends Rule[LogicalPlan] {
}

/**
* Create a new metadata struct where the Row ID values are populated using
* the materialized values if present, or the default Row ID values if not.
* Create a new metadata struct where the Row ID and row commit version values are populated using
* the materialized values if present, or the default Row ID / row commit version values if not.
*/
private def metadataWithRowTrackingColumnsProjection(metadata: AttributeReference)
: NamedExpression = {
val metadataFields = metadata.dataType.asInstanceOf[StructType].map {
case field if field.name == RowId.ROW_ID =>
field -> rowIdExpr(metadata)
case field if field.name == RowCommitVersion.METADATA_STRUCT_FIELD_NAME =>
field -> rowCommitVersionExpr(metadata)
case field =>
field -> getField(metadata, field.name)
}.flatMap { case (oldField, newExpr) =>
Expand Down
119 changes: 119 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/delta/RowCommitVersion.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.actions.{Metadata, Protocol}
import org.apache.spark.sql.util.ScalaExtensions._

import org.apache.spark.sql.{types, Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, FileSourceGeneratedMetadataStructField}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.datasources.FileFormat
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.types.{DataType, LongType, MetadataBuilder, StructField}

object RowCommitVersion {

val METADATA_STRUCT_FIELD_NAME = "row_commit_version"

val QUALIFIED_COLUMN_NAME = s"${FileFormat.METADATA_NAME}.$METADATA_STRUCT_FIELD_NAME"

def createMetadataStructField(protocol: Protocol, metadata: Metadata, nullable: Boolean = false)
: Option[StructField] =
MaterializedRowCommitVersion.getMaterializedColumnName(protocol, metadata)
.map(MetadataStructField(_, nullable))

/**
* Add a new column to `dataFrame` that has the name of the materialized Row Commit Version column
* and holds Row Commit Versions. The column also is tagged with the appropriate metadata such
* that it can be used to write materialized Row Commit Versions.
*/
private[delta] def preserveRowCommitVersions(
dataFrame: DataFrame,
snapshot: SnapshotDescriptor): DataFrame = {
if (!RowTracking.isEnabled(snapshot.protocol, snapshot.metadata)) {
return dataFrame
}

val materializedColumnName = MaterializedRowCommitVersion.getMaterializedColumnNameOrThrow(
snapshot.protocol, snapshot.metadata, snapshot.deltaLog.tableId)

val rowCommitVersionColumn =
DeltaTableUtils.getFileMetadataColumn(dataFrame).getField(METADATA_STRUCT_FIELD_NAME)
preserveRowCommitVersionsUnsafe(dataFrame, materializedColumnName, rowCommitVersionColumn)
}

private[delta] def preserveRowCommitVersionsUnsafe(
dataFrame: DataFrame,
materializedColumnName: String,
rowCommitVersionColumn: Column): DataFrame = {
dataFrame
.withColumn(materializedColumnName, rowCommitVersionColumn)
.withMetadata(materializedColumnName, MetadataStructField.metadata(materializedColumnName))
}

object MetadataStructField {
private val METADATA_COL_ATTR_KEY = "__row_commit_version_metadata_col"

def apply(materializedColumnName: String, nullable: Boolean = false): StructField =
StructField(
METADATA_STRUCT_FIELD_NAME,
LongType,
// The Row commit version field is used to read the materialized Row commit version value
// which is nullable. The actual Row commit version expression is created using a projection
// injected before the optimizer pass by the [[GenerateRowIDs] rule at which point the Row
// commit version field is non-nullable.
nullable,
metadata = metadata(materializedColumnName))

def unapply(field: StructField): Option[StructField] =
Option.when(isValid(field.dataType, field.metadata))(field)

def metadata(materializedColumnName: String): types.Metadata = new MetadataBuilder()
.withMetadata(
FileSourceGeneratedMetadataStructField.metadata(
METADATA_STRUCT_FIELD_NAME, materializedColumnName))
.putBoolean(METADATA_COL_ATTR_KEY, value = true)
.build()

/** Return true if the column is a Row Commit Version column. */
def isRowCommitVersionColumn(structField: StructField): Boolean =
isValid(structField.dataType, structField.metadata)

private[delta] def isValid(dataType: DataType, metadata: types.Metadata): Boolean = {
FileSourceGeneratedMetadataStructField.isValid(dataType, metadata) &&
metadata.contains(METADATA_COL_ATTR_KEY) &&
metadata.getBoolean(METADATA_COL_ATTR_KEY)
}
}

def columnMetadata(materializedColumnName: String): types.Metadata =
MetadataStructField.metadata(materializedColumnName)

object MetadataAttribute {
def apply(materializedColumnName: String): AttributeReference =
DataTypeUtils.toAttribute(MetadataStructField(materializedColumnName))
.withName(materializedColumnName)

def unapply(attr: Attribute): Option[Attribute] =
if (isRowCommitVersionColumn(attr)) Some(attr) else None

/** Return true if the column is a Row Commit Version column. */
def isRowCommitVersionColumn(attr: Attribute): Boolean =
MetadataStructField.isValid(attr.dataType, attr.metadata)
}
}
42 changes: 41 additions & 1 deletion spark/src/main/scala/org/apache/spark/sql/delta/RowId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.spark.sql.delta.actions.{Action, AddFile, DomainMetadata, Meta
import org.apache.spark.sql.delta.actions.TableFeatureProtocolUtils.propertyKey
import org.apache.spark.sql.util.ScalaExtensions._

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, FileSourceConstantMetadataStructField, FileSourceGeneratedMetadataStructField}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
Expand All @@ -30,6 +30,12 @@ import org.apache.spark.sql.types.{DataType, LongType, MetadataBuilder, StructFi

/**
* Collection of helpers to handle Row IDs.
*
* This file includes the following Row ID features:
* - Enabling Row IDs using table feature and table property.
* - Assigning fresh Row IDs.
* - Reading back Row IDs.
* - Preserving stable Row IDs.
*/
object RowId {
/**
Expand Down Expand Up @@ -268,4 +274,38 @@ object RowId {
case _ =>
}
}

/**
* Add a new column to 'dataFrame' that has the name of the materialized Row ID column and holds
* Row IDs. The column also is tagged with the appropriate metadata such that it can be used to
* write materialized Row IDs.
*/
private[delta] def preserveRowIds(
dataFrame: DataFrame,
snapshot: SnapshotDescriptor): DataFrame = {
if (!isEnabled(snapshot.protocol, snapshot.metadata)) {
return dataFrame
}

val materializedColumnName = MaterializedRowId.getMaterializedColumnNameOrThrow(
snapshot.protocol, snapshot.metadata, snapshot.deltaLog.tableId)

val rowIdColumn = DeltaTableUtils.getFileMetadataColumn(dataFrame).getField(ROW_ID)
preserveRowIdsUnsafe(dataFrame, materializedColumnName, rowIdColumn)
}

/**
* Add a new column to 'dataFrame' that has 'materializedColumnName' and holds Row IDs. The column
* is also tagged with the appropriate metadata so it can be used to write materialized Row IDs.
*
* Internal method, exposed only for testing.
*/
private[delta] def preserveRowIdsUnsafe(
dataFrame: DataFrame,
materializedColumnName: String,
rowIdColumn: Column): DataFrame = {
dataFrame
.withColumn(materializedColumnName, rowIdColumn)
.withMetadata(materializedColumnName, columnMetadata(materializedColumnName))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.actions.{Metadata, Protocol, TableFeatureProtocolUtils}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructField

/**
Expand Down Expand Up @@ -70,6 +71,12 @@ object RowTracking {
: Iterable[StructField] = {
RowId.createRowIdField(protocol, metadata, nullable) ++
RowId.createBaseRowIdField(protocol, metadata) ++
DefaultRowCommitVersion.createDefaultRowCommitVersionField(protocol, metadata)
DefaultRowCommitVersion.createDefaultRowCommitVersionField(protocol, metadata) ++
RowCommitVersion.createMetadataStructField(protocol, metadata, nullable)
}

def preserveRowTrackingColumns(dataFrame: DataFrame, snapshot: SnapshotDescriptor): DataFrame = {
val dfWithRowIds = RowId.preserveRowIds(dataFrame, snapshot)
RowCommitVersion.preserveRowCommitVersions(dfWithRowIds, snapshot)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

package org.apache.spark.sql.delta.schema

// scalastyle:off import.ordering.noEmptyLine
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import org.apache.spark.sql.delta.{DeltaAnalysisException, DeltaColumnMappingMode, DeltaErrors, DeltaLog, GeneratedColumn, NoMapping, TypeWidening}
import org.apache.spark.sql.delta.RowId
import org.apache.spark.sql.delta.{RowCommitVersion, RowId}
import org.apache.spark.sql.delta.actions.Protocol
import org.apache.spark.sql.delta.commands.cdc.CDCReader
import org.apache.spark.sql.delta.metering.DeltaLogging
Expand Down Expand Up @@ -324,6 +323,8 @@ def normalizeColumnNamesInDataType(
// Consider Row Id columns internal if Row Ids are enabled.
case None if RowId.RowIdMetadataStructField.isRowIdColumn(field) =>
(field.name, None)
case None if RowCommitVersion.MetadataStructField.isRowCommitVersionColumn(field) =>
(field.name, None)
case None =>
throw DeltaErrors.cannotResolveColumn(field.name, baseSchema)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

package org.apache.spark.sql.delta.rowid

import org.apache.spark.sql.delta.{RowCommitVersion, RowId}
import org.apache.spark.sql.delta.DeltaTestUtils.BOOLEAN_DOMAIN
import org.apache.spark.sql.delta.RowId

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, Coalesce, EqualTo, Expression, FileSourceMetadataAttribute, GetStructField, MetadataAttributeWithLogicalName}
Expand Down Expand Up @@ -88,6 +88,21 @@ class GenerateRowIDsSuite extends QueryTest with RowIdTestUtils {
}
}

/**
* Checks that the given expression corresponds to the an expression used to generate Row commit
* versions:
* coalesce(_metadata.row_commit_version, _metadata.default_row_commit_version).
*/
protected def checkRowCommitVersionExpr(expr: Expression): Unit = expr match {
case Coalesce(
Seq(
GetStructField(FileSourceMetadataAttribute(_), _, _),
GetStructField(FileSourceMetadataAttribute(_), _, _))) => ()
case Alias(aliasedExpr, RowCommitVersion.METADATA_STRUCT_FIELD_NAME) =>
checkRowCommitVersionExpr(aliasedExpr)
case _ => fail(s"Expression didn't match expected Row commit version expression: $expr")
}

/**
* Checks that a metadata column is present in `output` and that it contains the given fields and
* only these.
Expand Down Expand Up @@ -162,6 +177,47 @@ class GenerateRowIDsSuite extends QueryTest with RowIdTestUtils {
}
}

testRowIdPlan("Row commit version column selected",
sql(s"SELECT _metadata.row_commit_version FROM $testTable")) {
// Selecting Row commit versions injects an expression to generate default Row commit versions.
case Project(Seq(rowIdExpr), lr: LogicalRelation) =>
assert(rowIdExpr.name == RowCommitVersion.METADATA_STRUCT_FIELD_NAME)
checkRowCommitVersionExpr(rowIdExpr)
assert(lr.output.map(_.name) === Seq("id", "_metadata"))
checkMetadataFieldsPresent(lr.output, Seq("default_row_commit_version", "row_commit_version"))
}

testRowIdPlan("Filter on Row commit version column",
sql(s"SELECT * FROM $testTable WHERE _metadata.row_commit_version = 5")) {
// Filtering on Row commit version injects an expression to generate default Row commit version
// in the filter.
case Project(projectList, Filter(EqualTo(rowIdExpr, _), lr: LogicalRelation)) =>
assert(projectList.map(_.name) === Seq("id"), "Project list didn't match")
checkRowCommitVersionExpr(rowIdExpr)
assert(lr.output.map(_.name) === Seq("id", "_metadata"), "Scan list didn't match")
checkMetadataFieldsPresent(lr.output, Seq("default_row_commit_version", "row_commit_version"))
}

testRowIdPlan("Filter on Row commit version in subquery",
sql(s"SELECT * FROM $testTable WHERE _metadata.row_commit_version IN (SELECT id FROM " +
s"$testTable)")) {
// Filtering on Row commit versions using a subquery injects an expression to generate default
// Row commit versions in the subquery.
case Project(
projectList,
Join(right: LogicalRelation, left: LogicalPlan, _, joinCond, _)) =>
assert(projectList.map(_.name) === Seq("id"), "Project list didn't match")
assert(right.output.map(_.name) === Seq("id", "_metadata"), "Outer scan output didn't match")
checkMetadataFieldsPresent(right.output,
Seq("default_row_commit_version", "row_commit_version"))
assert(left.output.map(_.name) === Seq("id"), "Subquery scan output didn't match")
joinCond match {
case Some(EqualTo(rowIdExpr, _)) =>
checkRowCommitVersionExpr(rowIdExpr)
case _ => fail(s"Subquery was transformed into a join with an unexpected condition.")
}
}

testRowIdPlan("Rename metadata column",
sql(s"SELECT renamed_metadata FROM (SELECT _metadata AS renamed_metadata FROM $testTable)"
)) {
Expand Down
Loading

0 comments on commit 1b210c2

Please sign in to comment.