From 2c450feba1b9e26ac2b3c6019a6bdf42a70583ce Mon Sep 17 00:00:00 2001 From: Carmen Kwan Date: Fri, 19 Jul 2024 23:59:30 +0200 Subject: [PATCH] [Spark] Identity Columns Value Generation (without MERGE support) (#3023) #### Which Delta project/connector is this regarding? - [x] Spark - [ ] Standalone - [ ] Flink - [ ] Kernel - [ ] Other (fill in here) ## Description This PR is part of https://github.com/delta-io/delta/issues/1959 In this PR, we enable basic ingestion for Identity Columns. * We use a custom UDF `GenerateIdentityValues` to generate values when not supplemented by the user. * We introduce classes to help update and track the high watermark of identity columns. * We also do some cleanup/ improve readability for ColumnWithDefaultExprUtils Note: This does NOT enable Ingestion with MERGE INTO yet. That will come in a follow up PR, to make this easier to review. ## How was this patch tested? We introduce a new test suite IdentityColumnIngestionSuite. ## Does this PR introduce _any_ user-facing changes? No. --- .../delta/ColumnWithDefaultExprUtils.scala | 27 +- .../spark/sql/delta/IdentityColumn.scala | 232 ++++++++++++++ .../sql/delta/OptimisticTransaction.scala | 56 +++- .../sql/delta/commands/WriteIntoDelta.scala | 22 +- .../sql/delta/files/TransactionalWrite.scala | 31 +- .../delta/IdentityColumnIngestionSuite.scala | 290 ++++++++++++++++++ .../sql/delta/IdentityColumnTestUtils.scala | 55 ++++ 7 files changed, 702 insertions(+), 11 deletions(-) create mode 100644 spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnIngestionSuite.scala diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/ColumnWithDefaultExprUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/ColumnWithDefaultExprUtils.scala index c666b09b0d6..790da73f567 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/ColumnWithDefaultExprUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/ColumnWithDefaultExprUtils.scala @@ -38,6 +38,10 @@ import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} /** * Provide utilities to handle columns with default expressions. + * Currently we support three types of such columns: + * (1) GENERATED columns. + * (2) IDENTITY columns. + * (3) Columns with user-specified default value expression. */ object ColumnWithDefaultExprUtils extends DeltaLogging { val USE_NULL_AS_DEFAULT_DELTA_OPTION = "__use_null_as_default" @@ -60,6 +64,7 @@ object ColumnWithDefaultExprUtils extends DeltaLogging { // Return if `protocol` satisfies the requirement for IDENTITY columns. def satisfiesIdentityColumnProtocol(protocol: Protocol): Boolean = + protocol.isFeatureSupported(IdentityColumnsTableFeature) || protocol.minWriterVersion == 6 || protocol.writerFeatureNames.contains("identityColumns") // Return true if the column `col` has default expressions (and can thus be omitted from the @@ -68,22 +73,18 @@ object ColumnWithDefaultExprUtils extends DeltaLogging { protocol: Protocol, col: StructField, nullAsDefault: Boolean): Boolean = { + isIdentityColumn(col) || col.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) || (col.nullable && nullAsDefault) || GeneratedColumn.isGeneratedColumn(protocol, col) } - // Return true if the column `col` cannot be included as the input data column of COPY INTO. - // TODO: ideally column with default value can be optionally excluded. 
- def shouldBeExcludedInCopyInto(protocol: Protocol, col: StructField): Boolean = { - GeneratedColumn.isGeneratedColumn(protocol, col) - } - // Return true if the table with `metadata` has default expressions. def tableHasDefaultExpr( protocol: Protocol, metadata: Metadata, nullAsDefault: Boolean): Boolean = { + hasIdentityColumn(metadata.schema) || metadata.schema.exists { f => f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) || (f.nullable && nullAsDefault) @@ -102,7 +103,8 @@ object ColumnWithDefaultExprUtils extends DeltaLogging { * @param data The data to be written into the table. * @param nullAsDefault If true, use null literal as the default value for missing columns. * @return The data with potentially additional default expressions projected and constraints - * from generated columns if any. + * from generated columns if any. This includes IDENTITY column names for which we + * should track the high water marks. */ def addDefaultExprsOrReturnConstraints( deltaLog: DeltaLog, @@ -114,6 +116,7 @@ object ColumnWithDefaultExprUtils extends DeltaLogging { val topLevelOutputNames = CaseInsensitiveMap(data.schema.map(f => f.name -> f).toMap) lazy val metadataOutputNames = CaseInsensitiveMap(schema.map(f => f.name -> f).toMap) val constraints = mutable.ArrayBuffer[Constraint]() + // Column names for which we will track high water marks. val track = mutable.Set[String]() var selectExprs = schema.flatMap { f => GeneratedColumn.getGenerationExpression(f) match { @@ -128,6 +131,15 @@ object ColumnWithDefaultExprUtils extends DeltaLogging { Some(new Column(expr).alias(f.name)) } case _ => + if (isIdentityColumn(f)) { + if (topLevelOutputNames.contains(f.name)) { + Some(SchemaUtils.fieldToColumn(f)) + } else { + // Track high water marks for generated IDENTITY values. 
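+            // The column is missing from the input, so its values will be produced by
+            // `GenerateIdentityValues`; the max/min of the generated values is collected by the
+            // stats tracker and becomes the column's new high water mark at commit time.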
+ track += f.name + Some(IdentityColumn.createIdentityColumnGenerationExprAsColumn(f)) + } + } else { if (topLevelOutputNames.contains(f.name) || !data.sparkSession.conf.get(DeltaSQLConf.GENERATED_COLUMN_ALLOW_NULLABLE)) { Some(SchemaUtils.fieldToColumn(f)) @@ -137,6 +149,7 @@ object ColumnWithDefaultExprUtils extends DeltaLogging { // The actual check for nullability on data is done in the DeltaInvariantCheckerExec getDefaultValueExprOrNullLit(f, nullAsDefault).map(new Column(_)) } + } } } val cdcSelectExprs = CDCReader.CDC_COLUMNS_IN_DATA.flatMap { cdcColumnName => diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/IdentityColumn.scala b/spark/src/main/scala/org/apache/spark/sql/delta/IdentityColumn.scala index e92eeb5b09d..e4e70c93b6f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/IdentityColumn.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/IdentityColumn.scala @@ -17,6 +17,19 @@ package org.apache.spark.sql.delta import org.apache.spark.sql.delta.metering.DeltaLogging +import org.apache.spark.sql.delta.sources.DeltaSourceUtils._ +import org.apache.spark.sql.delta.stats.{DeltaFileStatistics, DeltaJobStatisticsTracker} +import org.apache.spark.sql.delta.util.JsonUtils +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.{Column, DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.execution.datasources.WriteTaskStats +import org.apache.spark.sql.functions.{array, max, min, to_json} +import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} /** * Provide utility methods related to IDENTITY column support for Delta. @@ -26,4 +39,223 @@ object IdentityColumn extends DeltaLogging { // Default start and step configuration if not specified by user. val defaultStart = 1 val defaultStep = 1 + // Operation types in usage logs. + // When IDENTITY columns are defined. + val opTypeDefinition = "delta.identityColumn.definition" + // When table with IDENTITY columns are written into. + val opTypeWrite = "delta.identityColumn.write" + + // Return all the IDENTITY columns from `schema`. + def getIdentityColumns(schema: StructType): Seq[StructField] = { + schema.filter(ColumnWithDefaultExprUtils.isIdentityColumn) + } + + // Return the number of IDENTITY columns in `schema`. + private def getNumberOfIdentityColumns(schema: StructType): Int = { + getIdentityColumns(schema).size + } + + // Create expression to generate IDENTITY values for the column `field`. + def createIdentityColumnGenerationExpr(field: StructField): Expression = { + val info = IdentityColumn.getIdentityInfo(field) + GenerateIdentityValues(info.start, info.step, info.highWaterMark) + } + + // Create a column to generate IDENTITY values for the column `field`. + def createIdentityColumnGenerationExprAsColumn(field: StructField): Column = { + new Column(createIdentityColumnGenerationExpr(field)).alias(field.name) + } + + /** + * Create a stats tracker to collect IDENTITY column high water marks if its values are system + * generated. + * + * @param spark The SparkSession associated with this query. + * @param hadoopConf The Hadoop configuration object to use on an executor. + * @param path Root Reservoir path + * @param schema The schema of the table to be written into. 
+ * @param statsDataSchema The schema of the output data (this does not include partition columns). + * @param trackHighWaterMarks Column names for which we should track high water marks. + * @return The stats tracker. + */ + def createIdentityColumnStatsTracker( + spark: SparkSession, + hadoopConf: Configuration, + path: Path, + schema: StructType, + statsDataSchema: Seq[Attribute], + trackHighWaterMarks: Set[String] + ) : Option[DeltaIdentityColumnStatsTracker] = { + if (trackHighWaterMarks.isEmpty) return None + val identityColumnInfo = schema + .filter(f => trackHighWaterMarks.contains(f.name)) + .map(f => DeltaColumnMapping.getPhysicalName(f) -> // Get identity column physical names + (f.metadata.getLong(IDENTITY_INFO_STEP) > 0L)) + // We should have found all IDENTITY columns to track high water marks. + assert(identityColumnInfo.size == trackHighWaterMarks.size, + s"expect: $trackHighWaterMarks, found (physical names): ${identityColumnInfo.map(_._1)}") + // Build the expression to collect high water marks of all IDENTITY columns as a single + // expression. It is essentially a json array containing one max or min aggregate expression + // for each IDENTITY column. + // + // Example: for the following table + // + // CREATE TABLE t1 ( + // id1 BIGINT GENERATED BY DEFAULT AS IDENTITY (START WITH 1 INCREMENT BY 1), + // id2 BIGINT GENERATED BY DEFAULT AS IDENTITY (START WITH 1 INCREMENT BY -1), + // value STRING + // ) USING delta; + // + // The expression will be: to_json(array(max(id1), min(id2))) + val aggregates = identityColumnInfo.map { + case (name, positiveStep) => + val col = new Column(UnresolvedAttribute.quoted(name)) + if (positiveStep) max(col) else min(col) + } + val unresolvedExpr = to_json(array(aggregates: _*)) + // Resolve the collection expression by constructing a query to select the expression from a + // table with the statsSchema and get the analyzed expression. + val resolvedExpr = Dataset.ofRows(spark, LocalRelation(statsDataSchema)) + .select(unresolvedExpr).queryExecution.analyzed.expressions.head + Some(new DeltaIdentityColumnStatsTracker( + hadoopConf, + path, + statsDataSchema, + resolvedExpr, + identityColumnInfo + )) + } + + /** + * Return a new schema with IDENTITY high water marks updated in the schema. + * The new high watermarks are decided based on the `updatedIdentityHighWaterMarks` and old high + * watermark values present in the passed `schema`. 
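+   * For example (illustrative, keyed by physical column name): for a column with step 1 and a
+   * current high water mark of 10, entries ("col", 14) and ("col", 12) yield a new high water
+   * mark of max(10, 14, 12) = 14; with a negative step the minimum is taken instead.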
+ */ + def updateSchema( + schema: StructType, + updatedIdentityHighWaterMarks: Seq[(String, Long)]) : StructType = { + val updatedIdentityHighWaterMarksGrouped = + updatedIdentityHighWaterMarks.groupBy(_._1).mapValues(v => v.map(_._2)) + StructType(schema.map { f => + updatedIdentityHighWaterMarksGrouped.get(DeltaColumnMapping.getPhysicalName(f)) match { + case Some(newWatermarks) if ColumnWithDefaultExprUtils.isIdentityColumn(f) => + val oldIdentityInfo = getIdentityInfo(f) + val positiveStep = oldIdentityInfo.step > 0 + val newHighWaterMark = if (positiveStep) { + oldIdentityInfo.highWaterMark.map(Math.max(_, newWatermarks.max)) + .getOrElse(newWatermarks.max) + } else { + oldIdentityInfo.highWaterMark.map(Math.min(_, newWatermarks.min)) + .getOrElse(newWatermarks.min) + } + val builder = new MetadataBuilder() + .withMetadata(f.metadata) + .putLong(IDENTITY_INFO_HIGHWATERMARK, newHighWaterMark) + f.copy(metadata = builder.build()) + case _ => + f + } + }) + } + def logTableCreation(deltaLog: DeltaLog, schema: StructType): Unit = { + val numIdentityColumns = getNumberOfIdentityColumns(schema) + if (numIdentityColumns != 0) { + recordDeltaEvent( + deltaLog, + opTypeDefinition, + data = Map( + "numIdentityColumns" -> numIdentityColumns + ) + ) + } + } + + def logTableWrite( + snapshot: Snapshot, + generatedIdentityColumns: Set[String], + numInsertedRowsOpt: Option[Long]): Unit = { + val identityColumns = getIdentityColumns(snapshot.schema) + if (identityColumns.nonEmpty) { + val explicitIdentityColumns = identityColumns.filter { + f => !generatedIdentityColumns.contains(f.name) + }.map(_.name) + recordDeltaEvent( + snapshot.deltaLog, + opTypeWrite, + data = Map( + "numInsertedRows" -> numInsertedRowsOpt, + "generatedIdentityColumnNames" -> generatedIdentityColumns.mkString(","), + "generatedIdentityColumnCount" -> generatedIdentityColumns.size, + "explicitIdentityColumnNames" -> explicitIdentityColumns.mkString(","), + "explicitIdentityColumnCount" -> explicitIdentityColumns.size + ) + ) + } + } + // Return IDENTITY information of column `field`. Caller must ensure `isIdentityColumn(field)` + // is true. + def getIdentityInfo(field: StructField): IdentityInfo = { + val md = field.metadata + val start = md.getLong(IDENTITY_INFO_START) + val step = md.getLong(IDENTITY_INFO_STEP) + // If system hasn't generated IDENTITY values for this column (either it hasn't been + // inserted into, or every inserts provided values for this IDENTITY column), high water mark + // field will not present in column metadata. In this case, high water mark will be set to + // (start - step) so that the first value generated is start (high water mark + step). + val highWaterMark = if (md.contains(IDENTITY_INFO_HIGHWATERMARK)) { + Some(md.getLong(IDENTITY_INFO_HIGHWATERMARK)) + } else { + None + } + IdentityInfo(start, step, highWaterMark) + } +} + +/** + * Stats tracker for IDENTITY column high water marks. The only difference between this class and + * `DeltaJobStatisticsTracker` is how the stats are aggregated on the driver. + * + * @param hadoopConf The Hadoop configuration object to use on an executor. + * @param path Root Reservoir path + * @param dataCols Resolved data (i.e. non-partitionBy) columns of the dataframe to be written. + * @param statsColExpr The expression to collect high water marks. + * @param identityColumnInfo Information of IDENTITY columns. It contains a pair of column name + * and whether it has a positive step for each IDENTITY column. 
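+ *                           For example, for an ascending `id1` and a descending `id2` (physical
+ *                           names), this is Seq(("id1", true), ("id2", false)), matching the
+ *                           per-file stats json built as to_json(array(max(id1), min(id2))) in
+ *                           `createIdentityColumnStatsTracker`.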
+ */ +class DeltaIdentityColumnStatsTracker( + @transient private val hadoopConf: Configuration, + @transient path: Path, + dataCols: Seq[Attribute], + statsColExpr: Expression, + val identityColumnInfo: Seq[(String, Boolean)] + ) + extends DeltaJobStatisticsTracker( + hadoopConf, + path, + dataCols, + statsColExpr + ) { + + // Map of column name to its corresponding collected high water mark. + var highWaterMarks = scala.collection.mutable.Map[String, Long]() + + // Process the stats on the driver. In `stats` we have a sequence of `DeltaFileStatistics`, + // whose stats is a map of file path to its corresponding array of high water marks in json. + override def processStats(stats: Seq[WriteTaskStats], jobCommitTime: Long): Unit = { + stats.map(_.asInstanceOf[DeltaFileStatistics]).flatMap(_.stats).map { + case (_, statsString) => + val fileHighWaterMarks = JsonUtils.fromJson[Array[Long]](statsString) + // We must have high water marks collected for all IDENTITY columns and we have guaranteed + // that their orders in the array follow the orders in `identityInfo` by aligning the + // order of expression and `identityColumnInfo` in `createIdentityColumnStatsTracker`. + require(fileHighWaterMarks.size == identityColumnInfo.size) + identityColumnInfo.zip(fileHighWaterMarks).map { + case ((name, positiveStep), value) => + val updated = highWaterMarks.get(name).map { v => + if (positiveStep) v.max(value) else v.min(value) + }.getOrElse(value) + highWaterMarks.update(name, updated) + } + } + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala b/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala index 670a171afe5..e9116d389f3 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.delta.implicits.addFileEncoder import org.apache.spark.sql.delta.logging.DeltaLogKeys import org.apache.spark.sql.delta.metering.DeltaLogging import org.apache.spark.sql.delta.schema.{SchemaMergingUtils, SchemaUtils} -import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.sql.delta.sources.{DeltaSourceUtils, DeltaSQLConf} import org.apache.spark.sql.delta.stats._ import org.apache.spark.sql.delta.storage.LogStore import org.apache.spark.sql.delta.util.{DeltaCommitFileProvider, JsonUtils} @@ -378,6 +378,51 @@ trait OptimisticTransactionImpl extends TransactionalWrite checkDeletionVectorFilesHaveWideBounds = false } + // An array of tuples where each tuple represents a pair (colName, newHighWatermark). + // This is collected after a write into Delta table with IDENTITY columns. If it's not + // empty, we will update the high water marks during transaction commit. Note that the same + // column can have multiple entries here if A single transaction involves multiple write + // operations. E.g. Overwrite+ReplaceWhere operation involves two phases: Phase-1 to write just + // new data and Phase-2 to delete old data. So both phases can generate tuples for a given column + // here. + protected val updatedIdentityHighWaterMarks = ArrayBuffer.empty[(String, Long)] + + // The names of columns for which we will track the IDENTITY high water marks at transaction + // writes. 
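+  // When this is left unset, `TransactionalWrite.writeFiles` falls back to the track set
+  // computed by `normalizeData()` for the data being written.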
+ protected var trackHighWaterMarks: Option[Set[String]] = None + + def setTrackHighWaterMarks(track: Set[String]): Unit = { + assert(trackHighWaterMarks.isEmpty, "The tracking set shouldn't have been set") + trackHighWaterMarks = Some(track) + } + + /** + * Records an update to the metadata that should be committed with this transaction. As this is + * called after write, it skips checking `!hasWritten`. We do not have a full protocol of what + * `updating metadata after write` should behave, as currently this is only used to update + * IDENTITY columns high water marks. As a result, it goes through all the steps needed to update + * schema BEFORE writes, except skipping the check mentioned above. Note that schema evolution + * and IDENTITY update can happen inside a single transaction so this function does not check + * we have only one metadata update in a transaction. + * + * IMPORTANT: It is the responsibility of the caller to ensure that files currently present in + * the table and written by this transaction are valid under the new metadata. + */ + private def updateMetadataAfterWrite(updatedMetadata: Metadata): Unit = { + updateMetadataInternal(updatedMetadata, ignoreDefaultProperties = false) + } + + // Called before commit to update table schema with collected IDENTITY column high water marks + // so that the change can be committed to delta log. + def precommitUpdateSchemaWithIdentityHighWaterMarks(): Unit = { + if (updatedIdentityHighWaterMarks.nonEmpty) { + val newSchema = IdentityColumn.updateSchema( + metadata.schema, updatedIdentityHighWaterMarks.toSeq) + val updatedMetadata = metadata.copy(schemaString = newSchema.json) + updateMetadataAfterWrite(updatedMetadata) + } + } + /** The set of distinct partitions that contain added files by current transaction. */ protected[delta] var partitionsAddedToOpt: Option[mutable.HashSet[Map[String, String]]] = None @@ -653,6 +698,10 @@ trait OptimisticTransactionImpl extends TransactionalWrite setNewProtocolWithFeaturesEnabledByMetadata(newMetadataTmp) } + if (isCreatingNewTable) { + IdentityColumn.logTableCreation(deltaLog, newMetadataTmp.schema) + } + newMetadataTmp = newMetadataTmp.copy(configuration = configsWithoutProtocolProps) Protocol.assertMetadataContainsNoProtocolProps(newMetadataTmp) @@ -1148,6 +1197,11 @@ trait OptimisticTransactionImpl extends TransactionalWrite // Check for internal SetTransaction conflicts and dedup. val finalActions = checkForSetTransactionConflictAndDedup(actions ++ this.actions.toSeq) + // Update schema for IDENTITY column writes if necessary. This has to be called before + // `prepareCommit` because it might change metadata and `prepareCommit` is responsible for + // converting updated metadata into a `Metadata` action. + precommitUpdateSchemaWithIdentityHighWaterMarks() + // Try to commit at the next version. 
var preparedActions = executionObserver.preparingCommit { diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/WriteIntoDelta.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/WriteIntoDelta.scala index c0cfc586a64..8f601047aa5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/WriteIntoDelta.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/WriteIntoDelta.scala @@ -16,6 +16,8 @@ package org.apache.spark.sql.delta.commands +import scala.collection.mutable + // scalastyle:off import.ordering.noEmptyLine import org.apache.spark.sql.delta._ import org.apache.spark.sql.delta.actions._ @@ -34,7 +36,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable -import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.execution.command.LeafRunnableCommand import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.metric.SQLMetric @@ -260,6 +262,24 @@ case class WriteIntoDelta( sparkSession.conf.get(DeltaSQLConf.REPLACEWHERE_DATACOLUMNS_WITH_CDF_ENABLED) && cdcExistsInRemoveOp) { var dataWithDefaultExprs = data + // Add identity columns if they are not in `data`. + // Column names for which we will track identity column high water marks. + val trackHighWaterMarks = mutable.Set.empty[String] + val topLevelOutputNames = CaseInsensitiveMap(data.schema.map(f => f.name -> f).toMap) + val selectExprs = txn.metadata.schema.map { f => + if (ColumnWithDefaultExprUtils.isIdentityColumn(f) && + !topLevelOutputNames.contains(f.name)) { + // Track high water marks for generated IDENTITY values. + trackHighWaterMarks += f.name + IdentityColumn.createIdentityColumnGenerationExprAsColumn(f) + } else { + SchemaUtils.fieldToColumn(f).alias(f.name) + } + } + if (trackHighWaterMarks.nonEmpty) { + txn.setTrackHighWaterMarks(trackHighWaterMarks.toSet) + dataWithDefaultExprs = data.select(selectExprs: _*) + } // pack new data and cdc data into an array of structs and unpack them into rows // to share values in outputCols on both branches, avoiding re-evaluating diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/files/TransactionalWrite.scala b/spark/src/main/scala/org/apache/spark/sql/delta/files/TransactionalWrite.scala index 67e5219ae2c..d71694190c4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/files/TransactionalWrite.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/files/TransactionalWrite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.delta.stats.{ DeltaJobStatisticsTracker, StatisticsCollection } +import org.apache.spark.sql.util.ScalaExtensions._ import org.apache.hadoop.fs.Path import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -379,12 +380,18 @@ trait TransactionalWrite extends DeltaLogging { self: OptimisticTransactionImpl val (data, partitionSchema) = performCDCPartition(inputData) val outputPath = deltaLog.dataPath - val (queryExecution, output, generatedColumnConstraints, _) = + val (queryExecution, output, generatedColumnConstraints, trackFromData) = normalizeData(deltaLog, writeOptions, data) + // Use the track set from the transaction if set, + // otherwise use the track set from `normalizeData()`. 
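+    // (In this patch the transaction-level set is only populated by WriteIntoDelta's
+    // replaceWhere-with-CDF path, which builds its own IDENTITY projection.)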
+ val trackIdentityHighWaterMarks = trackHighWaterMarks.getOrElse(trackFromData) + val partitioningColumns = getPartitioningColumns(partitionSchema, output) val committer = getCommitter(outputPath) + val (statsDataSchema, _) = getStatsSchema(output, partitionSchema) + // If Statistics Collection is enabled, then create a stats tracker that will be injected during // the FileFormatWriter.write call below and will collect per-file stats using // StatisticsCollection @@ -395,6 +402,15 @@ trait TransactionalWrite extends DeltaLogging { self: OptimisticTransactionImpl val constraints = Constraints.getAll(metadata, spark) ++ generatedColumnConstraints ++ additionalConstraints + val identityTrackerOpt = IdentityColumn.createIdentityColumnStatsTracker( + spark, + deltaLog.newDeltaHadoopConf(), + outputPath, + metadata.schema, + statsDataSchema, + trackIdentityHighWaterMarks + ) + SQLExecution.withNewExecutionId(queryExecution, Option("deltaTransactionalWrite")) { val outputSpec = FileFormatWriter.OutputSpec( outputPath.toString, @@ -450,13 +466,20 @@ trait TransactionalWrite extends DeltaLogging { self: OptimisticTransactionImpl partitionColumns = partitioningColumns, bucketSpec = None, statsTrackers = optionalStatsTracker.toSeq - ++ statsTrackers, + ++ statsTrackers + ++ identityTrackerOpt.toSeq, options = options) } catch { case InnerInvariantViolationException(violationException) => // Pull an InvariantViolationException up to the top level if it was the root cause. throw violationException } + statsTrackers.foreach { + case tracker: BasicWriteJobStatsTracker => + val numOutputRowsOpt = tracker.driverSideMetrics.get("numOutputRows").map(_.value) + IdentityColumn.logTableWrite(snapshot, trackIdentityHighWaterMarks, numOutputRowsOpt) + case _ => () + } } var resultFiles = @@ -490,6 +513,10 @@ trait TransactionalWrite extends DeltaLogging { self: OptimisticTransactionImpl if (resultFiles.nonEmpty && !isOptimize) registerPostCommitHook(AutoCompact) + // Record the updated high water marks to be used during transaction commit. + identityTrackerOpt.ifDefined { tracker => + updatedIdentityHighWaterMarks.appendAll(tracker.highWaterMarks.toSeq) + } resultFiles.toSeq ++ committer.changeFiles } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnIngestionSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnIngestionSuite.scala new file mode 100644 index 00000000000..c4f70bf12b2 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnIngestionSuite.scala @@ -0,0 +1,290 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.delta + +import java.io.PrintWriter + +import org.apache.spark.sql.delta.GeneratedAsIdentityType.{GeneratedAlways, GeneratedByDefault} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.types._ + +/** + * Identity Column test suite for ingestion, including insert-only MERGE. + */ +trait IdentityColumnIngestionSuiteBase extends IdentityColumnTestUtils { + + import testImplicits._ + + private val tblName = "identity_test" + private val tempTblName = "identity_test_temp" + private val tempCsvFileName = "test.csv" + + /** Helper function to write a single 'value' column into `sourcePath`. */ + private def setupSimpleCsvFiles(sourcePath: String, start: Int, end: Int): Unit = { + val writer = new PrintWriter(s"$sourcePath/$tempCsvFileName") + // Write header. + writer.write("value\n") + // Write values. + (start to end).foreach { v => + writer.write(s"$v\n") + } + writer.close() + } + + object IngestMode extends Enumeration { + // Ingest using data frame append v1. + val appendV1 = Value + + // Ingest using data frame append v2. + val appendV2 = Value + + // Ingest using "INSERT INTO ... VALUES". + val insertIntoValues = Value + + // Ingest using "INSERT INTO ... SELECT ...". + val insertIntoSelect = Value + + // Ingest using "INSERT OVERWRITE ... VALUES". + val insertOverwriteValues = Value + + // Ingest using "INSERT OVERWRITE ... SELECT ...". + val insertOverwriteSelect = Value + + + // Ingest using streaming query. + val streaming = Value + + // Ingest using MERGE INTO ... WHEN NOT MATCHED INSERT + val mergeInsert = Value + } + + case class IngestTestCase(start: Long, step: Long, iteration: Int, batchSize: Int) + + /** + * Helper function to test ingesting data to delta table with IDENTITY columns. + * + * @param start IDENTITY start configuration. + * @param step IDENTITY step configuration. + * @param iteration How many batch to ingest. + * @param batchSize How many rows to ingest in each batch. + * @param mode Specifies what command to use to ingest data. + */ + private def testIngestData( + start: Long, + step: Long, + iteration: Int, + batchSize: Int, + mode: IngestMode.Value): Unit = { + var highWaterMark = start - step + withTable(tblName) { + createTableWithIdColAndIntValueCol( + tblName, GeneratedAlways, startsWith = Some(start), incrementBy = Some(step)) + val deltaLog = DeltaLog.forTable(spark, TableIdentifier(tblName)) + for (iter <- 0 to iteration - 1) { + val batchStart = iter * batchSize + 1 + val batchEnd = (iter + 1) * batchSize + + // Used by data frame append v1 and append v2. + val df = (batchStart to batchEnd).toDF("value") + // Used by insertInto, insertIntoSelect, insertOverwrite, insertOverwriteSelect + val insertValues = (batchStart to batchEnd).map(v => s"($v)").mkString(",") + + mode match { + case IngestMode.appendV1 => + df.write.format("delta").mode("append").save(deltaLog.dataPath.toString) + + case IngestMode.appendV2 => + df.writeTo(tblName).append() + + case IngestMode.insertIntoValues => + val insertStmt = s"INSERT INTO $tblName(value) VALUES $insertValues;" + sql(insertStmt) + + case IngestMode.insertIntoSelect => + withTable(tempTblName) { + // Insert values into a separate table, then select into the destination table. 
+ createTable( + tempTblName, Seq(TestColumnSpec(colName = "value", dataType = IntegerType))) + sql(s"INSERT INTO $tempTblName VALUES $insertValues") + sql(s"INSERT INTO $tblName(value) SELECT value FROM $tempTblName") + } + + case IngestMode.insertOverwriteSelect => + withTable(tempTblName) { + // Insert values into a separate table, then select into the destination table. + createTable( + tempTblName, Seq(TestColumnSpec(colName = "value", dataType = IntegerType))) + sql(s"INSERT INTO $tempTblName VALUES $insertValues") + sql(s"INSERT OVERWRITE $tblName(value) SELECT value FROM $tempTblName") + } + + case IngestMode.insertOverwriteValues => + val insertStmt = s"INSERT OVERWRITE $tblName(value) VALUES $insertValues" + sql(insertStmt) + + case IngestMode.streaming => + withTempDir { checkpointDir => + val stream = MemoryStream[Int] + val q = stream + .toDF + .toDF("value") + .writeStream + .format("delta") + .outputMode("append") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start(deltaLog.dataPath.toString) + stream.addData(batchStart to batchEnd) + q.processAllAvailable() + q.stop() + } + + case IngestMode.mergeInsert => + withTable(tempTblName) { + // Insert values into a separate table, then merge into the destination table. + createTable( + tempTblName, Seq(TestColumnSpec(colName = "value", dataType = IntegerType))) + sql(s"INSERT INTO $tempTblName VALUES $insertValues") + sql( + s""" + |MERGE INTO $tblName + | USING $tempTblName ON $tblName.value = $tempTblName.value + | WHEN NOT MATCHED THEN INSERT (value) VALUES ($tempTblName.value) + |""".stripMargin) + } + + case _ => assert(false, "Unrecognized ingestion mode") + } + + val expectedRowCount = mode match { + case _@(IngestMode.insertOverwriteValues | IngestMode.insertOverwriteSelect) => + // These modes keep the row count unchanged. 
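+            // Each iteration overwrites the previous batch, so the table always holds exactly
+            // `batchSize` rows.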
+ batchSize + case _ => batchSize * (iter + 1) + } + + highWaterMark = validateIdentity(tblName, expectedRowCount, start, step, + batchStart, batchEnd, highWaterMark) + } + } + } + + test("append v1") { + val testCases = Seq( + IngestTestCase(1, 1, 4, 250), + IngestTestCase(1, -3, 10, 23) + ) + for (tc <- testCases) { + testIngestData(tc.start, tc.step, tc.iteration, tc.batchSize, IngestMode.appendV1) + } + } + + test("append v2") { + val testCases = Seq( + IngestTestCase(100, 100, 3, 300), + IngestTestCase(Integer.MAX_VALUE.toLong + 1, -1000, 10, 23) + ) + for (tc <- testCases) { + testIngestData(tc.start, tc.step, tc.iteration, tc.batchSize, IngestMode.appendV2) + } + } + + test("insert into values") { + val testCases = Seq( + IngestTestCase(100, -100, 4, 201), + IngestTestCase(Integer.MAX_VALUE.toLong + 1, 1000, 10, 37) + ) + for (tc <- testCases) { + testIngestData(tc.start, tc.step, tc.iteration, tc.batchSize, IngestMode.insertIntoValues) + } + } + + test("insert into select") { + val testCases = Seq( + IngestTestCase(23, 102, 3, 77), + IngestTestCase(Integer.MAX_VALUE.toLong - 12345, 99, 8, 25) + ) + for (tc <- testCases) { + testIngestData(tc.start, tc.step, tc.iteration, tc.batchSize, IngestMode.insertIntoSelect) + } + } + + test("insert overwrite values") { + val testCases = Seq( + IngestTestCase(-10, 3, 5, 30), + IngestTestCase(Integer.MIN_VALUE.toLong - 1000, -18, 2, 100) + ) + for (tc <- testCases) { + testIngestData(tc.start, tc.step, tc.iteration, tc.batchSize, + IngestMode.insertOverwriteValues) + } + } + + test("insert overwrite select") { + val testCases = Seq( + IngestTestCase(-15, 20, 4, 35), + IngestTestCase(200, 50, 3, 7) + ) + for (tc <- testCases) { + testIngestData(tc.start, tc.step, tc.iteration, tc.batchSize, + IngestMode.insertOverwriteSelect) + } + } + + test("streaming") { + val testCases = Seq( + IngestTestCase(-2000, 19, 5, 20), + IngestTestCase(10, 10, 4, 17) + ) + for (tc <- testCases) { + testIngestData(tc.start, tc.step, tc.iteration, tc.batchSize, IngestMode.streaming) + } + } + + test("explicit insert should not update high water mark") { + withIdentityColumnTable(GeneratedByDefault, tblName) { + val deltaLog = DeltaLog.forTable(spark, TableIdentifier(tblName)) + val schema1 = deltaLog.snapshot.metadata.schemaString + + // System generated IDENTITY value - should update schema. + sql(s"INSERT INTO $tblName(value) VALUES (1);") + val schema2 = deltaLog.snapshot.metadata.schemaString + assert(schema1 != schema2) + + // Explicitly provided IDENTITY value - should not update schema. 
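+      // No value is system generated here, so no high water mark is collected and the
+      // precommit metadata update is skipped.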
+ sql(s"INSERT INTO $tblName VALUES (1,1);") + val schema3 = deltaLog.snapshot.metadata.schemaString + assert(schema2 == schema3) + } + } +} + +class IdentityColumnIngestionScalaSuite + extends IdentityColumnIngestionSuiteBase + with ScalaDDLTestUtils + +class IdentityColumnIngestionScalaIdColumnMappingSuite + extends IdentityColumnIngestionSuiteBase + with ScalaDDLTestUtils + with DeltaColumnMappingEnableIdMode + +class IdentityColumnIngestionScalaNameColumnMappingSuite + extends IdentityColumnIngestionSuiteBase + with ScalaDDLTestUtils + with DeltaColumnMappingEnableNameMode diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnTestUtils.scala b/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnTestUtils.scala index a382ebfa12b..88df947e631 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnTestUtils.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnTestUtils.scala @@ -64,5 +64,60 @@ trait IdentityColumnTestUtils tblProperties = tblProperties ) } + + /** + * Creates and manages a simple identity column table with one other column "value" of type int + */ + protected def withIdentityColumnTable( + generatedAsIdentityType: GeneratedAsIdentityType, + tableName: String)(f: => Unit): Unit = { + withTable(tableName) { + createTableWithIdColAndIntValueCol(tableName, generatedAsIdentityType, None, None) + f + } + } + + /** + * Helper function to validate values of IDENTITY column `id` in table `tableName`. Returns the + * new high water mark. We use minValue and maxValue to filter column `value` to get the set of + * values we are checking in this batch. + */ + protected def validateIdentity( + tableName: String, + expectedRowCount: Long, + start: Long, + step: Long, + minValue: Long, + maxValue: Long, + oldHighWaterMark: Long): Long = { + // Check row count. + checkAnswer( + sql(s"SELECT COUNT(*) FROM $tableName"), + Row(expectedRowCount) + ) + // Check values are unique. + checkAnswer( + sql(s"SELECT COUNT(DISTINCT id) FROM $tableName"), + Row(expectedRowCount) + ) + // Check values follow start and step configuration. + checkAnswer( + sql(s"SELECT COUNT(*) FROM $tableName WHERE (id - $start) % $step != 0"), + Row(0) + ) + // Check values generated in this batch are after previous high water mark. + checkAnswer( + sql( + s""" + |SELECT COUNT(*) FROM $tableName + | WHERE (value BETWEEN $minValue and $maxValue) + | AND ((id - $oldHighWaterMark) / $step < 0) + |""".stripMargin), + Row(0) + ) + // Update high water mark. + val func = if (step > 0) "MAX" else "MIN" + sql(s"SELECT $func(id) FROM $tableName").collect().head.getLong(0) + } }