[Spark] Identity Columns Value Generation (without MERGE support) (#3023)

#### Which Delta project/connector is this regarding?
- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description
This PR is part of #1959

In this PR, we enable basic ingestion for Identity Columns.
* We use a custom UDF, `GenerateIdentityValues`, to generate values when they are
not supplied by the user.
* We introduce classes to help update and track the high water mark of
identity columns.
* We also do some cleanup and improve the readability of
`ColumnWithDefaultExprUtils`.

Note: This does NOT enable ingestion with MERGE INTO yet. That will come
in a follow-up PR to keep this one easier to review.
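
For illustration, a minimal end-to-end sketch of the behavior this enables (hypothetical table and column names; assumes a `SparkSession` named `spark` with Delta Lake configured; this is not code from the PR):

```scala
// Hypothetical example: table and column names are illustrative only.
spark.sql("""
  CREATE TABLE events (
    id BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 1 INCREMENT BY 1),
    payload STRING
  ) USING delta
""")

// `id` is omitted from the insert, so its values are produced by the
// GenerateIdentityValues expression and the column's high water mark is
// updated in the table metadata as part of the commit.
spark.sql("INSERT INTO events (payload) VALUES ('a'), ('b'), ('c')")
```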

## How was this patch tested?
We introduce a new test suite, `IdentityColumnIngestionSuite`.
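
For illustration only (this is not the actual suite code), the kind of property such tests can assert about system-generated values:

```scala
// Sketch only: generated IDENTITY values must be unique and lie on the
// start/step grid, although gaps are allowed (values need not be consecutive).
def looksLikeGeneratedIdentityValues(values: Seq[Long], start: Long, step: Long): Boolean =
  values.distinct.size == values.size &&
    values.forall(v => (v - start) % step == 0 && (v - start) / step >= 0)

assert(looksLikeGeneratedIdentityValues(Seq(10, 12, 16), start = 10, step = 2))   // gaps are fine
assert(!looksLikeGeneratedIdentityValues(Seq(10, 10, 12), start = 10, step = 2))  // duplicates are not
assert(!looksLikeGeneratedIdentityValues(Seq(9, 11), start = 10, step = 2))       // off the grid
```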

## Does this PR introduce _any_ user-facing changes?
No.
c27kwan authored Jul 19, 2024
1 parent 589caba commit 2c450fe
Showing 7 changed files with 702 additions and 11 deletions.
spark/src/main/scala/org/apache/spark/sql/delta/ColumnWithDefaultExprUtils.scala
@@ -38,6 +38,10 @@ import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}

/**
* Provide utilities to handle columns with default expressions.
* Currently we support three types of such columns:
* (1) GENERATED columns.
* (2) IDENTITY columns.
* (3) Columns with user-specified default value expression.
*/
object ColumnWithDefaultExprUtils extends DeltaLogging {
val USE_NULL_AS_DEFAULT_DELTA_OPTION = "__use_null_as_default"
@@ -60,6 +64,7 @@ object ColumnWithDefaultExprUtils extends DeltaLogging {

// Return whether `protocol` satisfies the requirement for IDENTITY columns.
def satisfiesIdentityColumnProtocol(protocol: Protocol): Boolean =
protocol.isFeatureSupported(IdentityColumnsTableFeature) ||
protocol.minWriterVersion == 6 || protocol.writerFeatureNames.contains("identityColumns")

// Return true if the column `col` has default expressions (and can thus be omitted from the
@@ -68,22 +73,18 @@ object ColumnWithDefaultExprUtils extends DeltaLogging {
protocol: Protocol,
col: StructField,
nullAsDefault: Boolean): Boolean = {
isIdentityColumn(col) ||
col.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) ||
(col.nullable && nullAsDefault) ||
GeneratedColumn.isGeneratedColumn(protocol, col)
}

// Return true if the column `col` cannot be included as the input data column of COPY INTO.
// TODO: ideally column with default value can be optionally excluded.
def shouldBeExcludedInCopyInto(protocol: Protocol, col: StructField): Boolean = {
GeneratedColumn.isGeneratedColumn(protocol, col)
}

// Return true if the table with `metadata` has default expressions.
def tableHasDefaultExpr(
protocol: Protocol,
metadata: Metadata,
nullAsDefault: Boolean): Boolean = {
hasIdentityColumn(metadata.schema) ||
metadata.schema.exists { f =>
f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) ||
(f.nullable && nullAsDefault)
@@ -102,7 +103,8 @@ object ColumnWithDefaultExprUtils extends DeltaLogging {
* @param data The data to be written into the table.
* @param nullAsDefault If true, use null literal as the default value for missing columns.
* @return The data with potentially additional default expressions projected and constraints
* from generated columns if any.
* from generated columns if any. This includes IDENTITY column names for which we
* should track the high water marks.
*/
def addDefaultExprsOrReturnConstraints(
deltaLog: DeltaLog,
@@ -114,6 +116,7 @@ object ColumnWithDefaultExprUtils extends DeltaLogging {
val topLevelOutputNames = CaseInsensitiveMap(data.schema.map(f => f.name -> f).toMap)
lazy val metadataOutputNames = CaseInsensitiveMap(schema.map(f => f.name -> f).toMap)
val constraints = mutable.ArrayBuffer[Constraint]()
// Column names for which we will track high water marks.
val track = mutable.Set[String]()
var selectExprs = schema.flatMap { f =>
GeneratedColumn.getGenerationExpression(f) match {
@@ -128,6 +131,15 @@ object ColumnWithDefaultExprUtils extends DeltaLogging {
Some(new Column(expr).alias(f.name))
}
case _ =>
if (isIdentityColumn(f)) {
if (topLevelOutputNames.contains(f.name)) {
Some(SchemaUtils.fieldToColumn(f))
} else {
// Track high water marks for generated IDENTITY values.
track += f.name
Some(IdentityColumn.createIdentityColumnGenerationExprAsColumn(f))
}
} else {
if (topLevelOutputNames.contains(f.name) ||
!data.sparkSession.conf.get(DeltaSQLConf.GENERATED_COLUMN_ALLOW_NULLABLE)) {
Some(SchemaUtils.fieldToColumn(f))
@@ -137,6 +149,7 @@ object ColumnWithDefaultExprUtils extends DeltaLogging {
// The actual check for nullability on data is done in the DeltaInvariantCheckerExec
getDefaultValueExprOrNullLit(f, nullAsDefault).map(new Column(_))
}
}
}
}
val cdcSelectExprs = CDCReader.CDC_COLUMNS_IN_DATA.flatMap { cdcColumnName =>
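
To summarize the `ColumnWithDefaultExprUtils` change above: when an IDENTITY column is absent from the input, a generation expression is projected into the output and the column name is added to `track` so its high water mark can be collected during the write. A rough, self-contained sketch of that projection in plain Spark (hypothetical column names; `monotonically_increasing_id` is only a stand-in for the real `GenerateIdentityValues` expression, which also honors start, step, and the current high water mark):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, monotonically_increasing_id}

val spark = SparkSession.builder().master("local[*]").appName("identity-sketch").getOrCreate()

// User-supplied data without the IDENTITY column `id`.
val input = spark.createDataFrame(Seq(("a", 1), ("b", 2))).toDF("payload", "score")

// Names of IDENTITY columns whose high water marks must be tracked after the write.
val track = scala.collection.mutable.Set[String]("id")

// Project a generated `id` alongside the user columns, mirroring the select
// expressions built in addDefaultExprsOrReturnConstraints.
val withGeneratedId = input.select(
  monotonically_increasing_id().alias("id"), // stand-in for GenerateIdentityValues
  col("payload"),
  col("score"))
```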
232 changes: 232 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/delta/IdentityColumn.scala
@@ -17,6 +17,19 @@
package org.apache.spark.sql.delta

import org.apache.spark.sql.delta.metering.DeltaLogging
import org.apache.spark.sql.delta.sources.DeltaSourceUtils._
import org.apache.spark.sql.delta.stats.{DeltaFileStatistics, DeltaJobStatisticsTracker}
import org.apache.spark.sql.delta.util.JsonUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.execution.datasources.WriteTaskStats
import org.apache.spark.sql.functions.{array, max, min, to_json}
import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}

/**
* Provide utility methods related to IDENTITY column support for Delta.
@@ -26,4 +39,223 @@ object IdentityColumn extends DeltaLogging {
// Default start and step configuration if not specified by user.
val defaultStart = 1
val defaultStep = 1
// Operation types in usage logs.
// When IDENTITY columns are defined.
val opTypeDefinition = "delta.identityColumn.definition"
// When a table with IDENTITY columns is written into.
val opTypeWrite = "delta.identityColumn.write"

// Return all the IDENTITY columns from `schema`.
def getIdentityColumns(schema: StructType): Seq[StructField] = {
schema.filter(ColumnWithDefaultExprUtils.isIdentityColumn)
}

// Return the number of IDENTITY columns in `schema`.
private def getNumberOfIdentityColumns(schema: StructType): Int = {
getIdentityColumns(schema).size
}

// Create expression to generate IDENTITY values for the column `field`.
def createIdentityColumnGenerationExpr(field: StructField): Expression = {
val info = IdentityColumn.getIdentityInfo(field)
GenerateIdentityValues(info.start, info.step, info.highWaterMark)
}

// Create a column to generate IDENTITY values for the column `field`.
def createIdentityColumnGenerationExprAsColumn(field: StructField): Column = {
new Column(createIdentityColumnGenerationExpr(field)).alias(field.name)
}

/**
* Create a stats tracker to collect IDENTITY column high water marks if its values are system
* generated.
*
* @param spark The SparkSession associated with this query.
* @param hadoopConf The Hadoop configuration object to use on an executor.
* @param path Root Reservoir path
* @param schema The schema of the table to be written into.
* @param statsDataSchema The schema of the output data (this does not include partition columns).
* @param trackHighWaterMarks Column names for which we should track high water marks.
* @return The stats tracker.
*/
def createIdentityColumnStatsTracker(
spark: SparkSession,
hadoopConf: Configuration,
path: Path,
schema: StructType,
statsDataSchema: Seq[Attribute],
trackHighWaterMarks: Set[String]
) : Option[DeltaIdentityColumnStatsTracker] = {
if (trackHighWaterMarks.isEmpty) return None
val identityColumnInfo = schema
.filter(f => trackHighWaterMarks.contains(f.name))
.map(f => DeltaColumnMapping.getPhysicalName(f) -> // Get identity column physical names
(f.metadata.getLong(IDENTITY_INFO_STEP) > 0L))
// We should have found all IDENTITY columns to track high water marks.
assert(identityColumnInfo.size == trackHighWaterMarks.size,
s"expect: $trackHighWaterMarks, found (physical names): ${identityColumnInfo.map(_._1)}")
// Build the expression to collect high water marks of all IDENTITY columns as a single
// expression. It is essentially a json array containing one max or min aggregate expression
// for each IDENTITY column.
//
// Example: for the following table
//
// CREATE TABLE t1 (
// id1 BIGINT GENERATED BY DEFAULT AS IDENTITY (START WITH 1 INCREMENT BY 1),
// id2 BIGINT GENERATED BY DEFAULT AS IDENTITY (START WITH 1 INCREMENT BY -1),
// value STRING
// ) USING delta;
//
// The expression will be: to_json(array(max(id1), min(id2)))
val aggregates = identityColumnInfo.map {
case (name, positiveStep) =>
val col = new Column(UnresolvedAttribute.quoted(name))
if (positiveStep) max(col) else min(col)
}
val unresolvedExpr = to_json(array(aggregates: _*))
// Resolve the collection expression by constructing a query to select the expression from a
// table with the statsSchema and get the analyzed expression.
val resolvedExpr = Dataset.ofRows(spark, LocalRelation(statsDataSchema))
.select(unresolvedExpr).queryExecution.analyzed.expressions.head
Some(new DeltaIdentityColumnStatsTracker(
hadoopConf,
path,
statsDataSchema,
resolvedExpr,
identityColumnInfo
))
}

/**
* Return a new schema with IDENTITY high water marks updated in the schema.
* The new high watermarks are decided based on the `updatedIdentityHighWaterMarks` and old high
* watermark values present in the passed `schema`.
*/
def updateSchema(
schema: StructType,
updatedIdentityHighWaterMarks: Seq[(String, Long)]) : StructType = {
val updatedIdentityHighWaterMarksGrouped =
updatedIdentityHighWaterMarks.groupBy(_._1).mapValues(v => v.map(_._2))
StructType(schema.map { f =>
updatedIdentityHighWaterMarksGrouped.get(DeltaColumnMapping.getPhysicalName(f)) match {
case Some(newWatermarks) if ColumnWithDefaultExprUtils.isIdentityColumn(f) =>
val oldIdentityInfo = getIdentityInfo(f)
val positiveStep = oldIdentityInfo.step > 0
val newHighWaterMark = if (positiveStep) {
oldIdentityInfo.highWaterMark.map(Math.max(_, newWatermarks.max))
.getOrElse(newWatermarks.max)
} else {
oldIdentityInfo.highWaterMark.map(Math.min(_, newWatermarks.min))
.getOrElse(newWatermarks.min)
}
val builder = new MetadataBuilder()
.withMetadata(f.metadata)
.putLong(IDENTITY_INFO_HIGHWATERMARK, newHighWaterMark)
f.copy(metadata = builder.build())
case _ =>
f
}
})
}
def logTableCreation(deltaLog: DeltaLog, schema: StructType): Unit = {
val numIdentityColumns = getNumberOfIdentityColumns(schema)
if (numIdentityColumns != 0) {
recordDeltaEvent(
deltaLog,
opTypeDefinition,
data = Map(
"numIdentityColumns" -> numIdentityColumns
)
)
}
}

def logTableWrite(
snapshot: Snapshot,
generatedIdentityColumns: Set[String],
numInsertedRowsOpt: Option[Long]): Unit = {
val identityColumns = getIdentityColumns(snapshot.schema)
if (identityColumns.nonEmpty) {
val explicitIdentityColumns = identityColumns.filter {
f => !generatedIdentityColumns.contains(f.name)
}.map(_.name)
recordDeltaEvent(
snapshot.deltaLog,
opTypeWrite,
data = Map(
"numInsertedRows" -> numInsertedRowsOpt,
"generatedIdentityColumnNames" -> generatedIdentityColumns.mkString(","),
"generatedIdentityColumnCount" -> generatedIdentityColumns.size,
"explicitIdentityColumnNames" -> explicitIdentityColumns.mkString(","),
"explicitIdentityColumnCount" -> explicitIdentityColumns.size
)
)
}
}
// Return IDENTITY information of column `field`. Caller must ensure `isIdentityColumn(field)`
// is true.
def getIdentityInfo(field: StructField): IdentityInfo = {
val md = field.metadata
val start = md.getLong(IDENTITY_INFO_START)
val step = md.getLong(IDENTITY_INFO_STEP)
// If the system hasn't generated IDENTITY values for this column (either it hasn't been
// inserted into, or every insert provided values for this IDENTITY column), the high water
// mark field will not be present in the column metadata. In this case, the high water mark
// will be set to (start - step) so that the first value generated is start (high water mark + step).
val highWaterMark = if (md.contains(IDENTITY_INFO_HIGHWATERMARK)) {
Some(md.getLong(IDENTITY_INFO_HIGHWATERMARK))
} else {
None
}
IdentityInfo(start, step, highWaterMark)
}
}

/**
* Stats tracker for IDENTITY column high water marks. The only difference between this class and
* `DeltaJobStatisticsTracker` is how the stats are aggregated on the driver.
*
* @param hadoopConf The Hadoop configuration object to use on an executor.
* @param path Root Reservoir path
* @param dataCols Resolved data (i.e. non-partitionBy) columns of the dataframe to be written.
* @param statsColExpr The expression to collect high water marks.
* @param identityColumnInfo Information of IDENTITY columns. It contains a pair of column name
* and whether it has a positive step for each IDENTITY column.
*/
class DeltaIdentityColumnStatsTracker(
@transient private val hadoopConf: Configuration,
@transient path: Path,
dataCols: Seq[Attribute],
statsColExpr: Expression,
val identityColumnInfo: Seq[(String, Boolean)]
)
extends DeltaJobStatisticsTracker(
hadoopConf,
path,
dataCols,
statsColExpr
) {

// Map of column name to its corresponding collected high water mark.
var highWaterMarks = scala.collection.mutable.Map[String, Long]()

// Process the stats on the driver. In `stats` we have a sequence of `DeltaFileStatistics`,
// whose stats is a map of file path to its corresponding array of high water marks in json.
override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long): Unit = {
stats.map(_.asInstanceOf[DeltaFileStatistics]).flatMap(_.stats).map {
case (_, statsString) =>
val fileHighWaterMarks = JsonUtils.fromJson[Array[Long]](statsString)
// We must have high water marks collected for all IDENTITY columns, and their order in the
// array is guaranteed to follow the order of `identityColumnInfo`, because the aggregate
// expressions were built in that order in `createIdentityColumnStatsTracker`.
require(fileHighWaterMarks.size == identityColumnInfo.size)
identityColumnInfo.zip(fileHighWaterMarks).map {
case ((name, positiveStep), value) =>
val updated = highWaterMarks.get(name).map { v =>
if (positiveStep) v.max(value) else v.min(value)
}.getOrElse(value)
highWaterMarks.update(name, updated)
}
}
}
}
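
To make the water-mark arithmetic in `getIdentityInfo` and `updateSchema` concrete, a small self-contained illustration in plain Scala (not Delta code):

```scala
// A missing high water mark is treated as (start - step), so the first value
// generated is exactly `start`; subsequent values advance by `step`.
def firstGeneratedValue(start: Long, step: Long, highWaterMark: Option[Long]): Long =
  highWaterMark.getOrElse(start - step) + step

// Merging an observed per-write mark into the stored one keeps the max for a
// positive step and the min for a negative step, mirroring updateSchema.
def mergedHighWaterMark(old: Option[Long], observed: Long, step: Long): Long =
  if (step > 0) old.map(math.max(_, observed)).getOrElse(observed)
  else old.map(math.min(_, observed)).getOrElse(observed)

assert(firstGeneratedValue(start = 1, step = 1, highWaterMark = None) == 1)
assert(firstGeneratedValue(start = 1, step = 1, highWaterMark = Some(7)) == 8)
assert(mergedHighWaterMark(old = Some(7), observed = 12, step = 1) == 12)
assert(mergedHighWaterMark(old = Some(-3), observed = -9, step = -1) == -9)
```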