From 280c878909a9a2d10062894b8ac30a772b44fa4a Mon Sep 17 00:00:00 2001
From: Carmen Kwan
Date: Thu, 25 Apr 2024 23:41:48 +0200
Subject: [PATCH] [Spark] Add GenerateIdentityValues UDF (#2915)

#### Which Delta project/connector is this regarding?

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

This PR is part of https://github.com/delta-io/delta/issues/1959

In this PR, we introduce the `GenerateIdentityValues` UDF used for populating Identity Column values. The UDF is not yet used by Delta in this PR.

`GenerateIdentityValues` is a simple non-deterministic UDF that keeps a counter based on the user-specified `start` and `step`. Each task counts in increments of `numPartitions * step`, so that value generation can be parallelized across tasks without producing duplicate values.

## How was this patch tested?

New test suite and unit tests for the UDF.

## Does this PR introduce _any_ user-facing changes?

No.

---
 .../sql/delta/GenerateIdentityValues.scala    | 156 ++++++++++++++++
 .../delta/GenerateIdentityValuesSuite.scala   | 176 ++++++++++++++++++
 2 files changed, 332 insertions(+)
 create mode 100644 spark/src/main/scala/org/apache/spark/sql/delta/GenerateIdentityValues.scala
 create mode 100644 spark/src/test/scala/org/apache/spark/sql/delta/GenerateIdentityValuesSuite.scala

diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/GenerateIdentityValues.scala b/spark/src/main/scala/org/apache/spark/sql/delta/GenerateIdentityValues.scala
new file mode 100644
index 00000000000..20da84f2c14
--- /dev/null
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/GenerateIdentityValues.scala
@@ -0,0 +1,156 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta
+
+import com.databricks.spark.util.MetricDefinitions
+import com.databricks.spark.util.TagDefinitions.TAG_OP_TYPE
+import org.apache.spark.sql.delta.metering.DeltaLogging
+
+import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.types.{DataType, LongType}
+
+/**
+ * Returns the next generated IDENTITY column value based on the underlying
+ * [[PartitionIdentityValueGenerator]].
+ */
+case class GenerateIdentityValues(generator: PartitionIdentityValueGenerator)
+  extends LeafExpression with Nondeterministic {
+
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    generator.initialize(partitionIndex)
+  }
+
+  override protected def evalInternal(input: InternalRow): Long = generator.next()
+
+  override def nullable: Boolean = false
+
+  /**
+   * Returns Java source code that can be compiled to evaluate this expression.
+   * The default behavior is to call the eval method of the expression.
+   * Concrete expression implementations should override this to do actual code generation.
+   *
+   * @param ctx a [[CodegenContext]]
+   * @param ev an [[ExprCode]] with unique terms.
+   * @return an [[ExprCode]] containing the Java source code to generate the given expression
+   */
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val generatorTerm = ctx.addReferenceObj("generator", generator,
+      classOf[PartitionIdentityValueGenerator].getName)
+    ctx.addPartitionInitializationStatement(s"$generatorTerm.initialize(partitionIndex);")
+    ev.copy(code = code"""
+      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $generatorTerm.next();
+      """, isNull = FalseLiteral)
+  }
+
+  /**
+   * Returns the [[DataType]] of the result of evaluating this expression. It is
+   * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false).
+   */
+  override def dataType: DataType = LongType
+}
+
+object GenerateIdentityValues {
+  def apply(start: Long, step: Long, highWaterMarkOpt: Option[Long]): GenerateIdentityValues = {
+    new GenerateIdentityValues(PartitionIdentityValueGenerator(start, step, highWaterMarkOpt))
+  }
+}
+
+/**
+ * Generator of IDENTITY values for one partition.
+ *
+ * @param start The configured start value for the identity column.
+ * @param step The IDENTITY value increment.
+ * @param highWaterMarkOpt The optional high water mark for identity value generation. If this
+ *                         is None, no identity values have been generated in the past and
+ *                         generation should start from `start`.
+ */
+case class PartitionIdentityValueGenerator(
+    start: Long,
+    step: Long,
+    highWaterMarkOpt: Option[Long]) {
+
+  require(step != 0)
+  // The value generation logic requires the high water mark to follow the start and step
+  // configuration.
+  highWaterMarkOpt.foreach(highWaterMark => require((highWaterMark - start) % step == 0))
+
+  private lazy val base = highWaterMarkOpt.map(Math.addExact(_, step)).getOrElse(start)
+  private var partitionIndex: Int = -1
+  private var nextValue: Long = -1L
+  private var increment: Long = -1L
+
+  def initialize(partitionIndex: Int): Unit = {
+    if (this.partitionIndex < 0) {
+      this.partitionIndex = partitionIndex
+      this.nextValue = try {
+        Math.addExact(base, Math.multiplyExact(partitionIndex, step))
+      } catch {
+        case e: ArithmeticException =>
+          IdentityOverflowLogger.logOverflow()
+          throw e
+      }
+      // Each value is incremented by numPartitions * step from the previous value.
+      this.increment = try {
+        // Total number of partitions. In the local execution case where TaskContext is not
+        // set, the task is executed as a single partition.
+        val numPartitions = Option(TaskContext.get()).map(_.numPartitions()).getOrElse(1)
+        Math.multiplyExact(numPartitions, step)
+      } catch {
+        case e: ArithmeticException =>
+          IdentityOverflowLogger.logOverflow()
+          throw e
+      }
+    } else if (this.partitionIndex != partitionIndex) {
+      throw SparkException.internalError("Same PartitionIdentityValueGenerator object " +
+        s"initialized with two different partitionIndex [oldValue: ${this.partitionIndex}, " +
+        s"newValue: $partitionIndex]")
+    }
+  }
+
+  private def assertInitialized(): Unit = if (partitionIndex == -1) {
+    throw SparkException.internalError("PartitionIdentityValueGenerator is not initialized.")
+  }
+
+  // Generate the next IDENTITY value.
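+  // Illustrative example (hypothetical values, not from this patch): with start = 0, step = 2,
+  // no high water mark, and 3 partitions, partition 0 generates 0, 6, 12, ..., partition 1
+  // generates 2, 8, 14, ..., and partition 2 generates 4, 10, 16, ... Each call advances the
+  // counter by numPartitions * step = 6, so tasks never produce colliding values.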
+  def next(): Long = {
+    try {
+      assertInitialized()
+      val ret = nextValue
+      nextValue = Math.addExact(nextValue, increment)
+      ret
+    } catch {
+      case e: ArithmeticException =>
+        IdentityOverflowLogger.logOverflow()
+        throw e
+    }
+  }
+}
+
+object IdentityOverflowLogger extends DeltaLogging {
+  def logOverflow(): Unit = {
+    recordEvent(
+      MetricDefinitions.EVENT_TAHOE,
+      Map(TAG_OP_TYPE -> "delta.identityColumn.overflow")
+    )
+  }
+}
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/GenerateIdentityValuesSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/GenerateIdentityValuesSuite.scala
new file mode 100644
index 00000000000..752c181fadb
--- /dev/null
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/GenerateIdentityValuesSuite.scala
@@ -0,0 +1,176 @@
+/*
+ * Copyright (2021) The Delta Lake Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.delta
+
+import com.databricks.spark.util.Log4jUsageLogger
+import org.apache.spark.sql.delta.IdentityColumn.IdentityInfo
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row}
+import org.apache.spark.sql.catalyst.expressions.{GreaterThan, If, Literal}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSparkSession
+
+class GenerateIdentityValuesSuite extends QueryTest with SharedSparkSession {
+  private val colName = "id"
+
+  /**
+   * Verify that the generated IDENTITY values are correct.
+   *
+   * @param df A DataFrame with a single column containing all the generated IDENTITY values.
+   * @param identityInfo IDENTITY information used for verification.
+   * @param rowCount Expected row count.
+   */
+  private def verifyIdentityValues(
+      df: => DataFrame,
+      identityInfo: IdentityInfo,
+      rowCount: Long): Unit = {
+    // Check that the row count is as expected.
+    checkAnswer(df.select(count(col(colName))), Row(rowCount))
+
+    // Check that there are no duplicates.
+    checkAnswer(df.select(count_distinct(new Column(colName))), Row(rowCount))
+
+    // Check that every value follows the start and step configuration.
+    val condViolateConfig = s"($colName - ${identityInfo.start}) % ${identityInfo.step} != 0"
+    checkAnswer(df.where(condViolateConfig), Seq.empty)
+
+    // Check that every value is past the high water mark (start - step is used as the
+    // effective high water mark when there is none, i.e. values must be >= start).
+    val highWaterMark =
+      identityInfo.highWaterMark.getOrElse(identityInfo.start - identityInfo.step)
+    val condViolateHighWaterMark = s"(($colName - $highWaterMark)/${identityInfo.step}) < 0"
+    checkAnswer(df.where(condViolateHighWaterMark), Seq.empty)
+
+    // When the high water mark is empty, the first value should be `start`.
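+    // (Hypothetical illustration: with start = 5, step = -3, and no high water mark, the
+    // largest generated value must be 5, hence max for a negative step and min for a
+    // positive one.)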
+    if (identityInfo.highWaterMark.isEmpty) {
+      val agg = if (identityInfo.step > 0) min(new Column(colName)) else max(new Column(colName))
+      checkAnswer(df.select(agg), Row(identityInfo.start))
+    }
+  }
+
+  test("basic") {
+    val sizes = Seq(100, 1000, 10000)
+    val slices = Seq(2, 7, 15)
+    val starts = Seq(-3, 0, 1, 5, 43)
+    val steps = Seq(-3, -2, -1, 1, 2, 3)
+    for (size <- sizes; slice <- slices; start <- starts; step <- steps) {
+      val highWaterMarks = Seq(None, Some((start + 100 * step).toLong))
+      val df = spark.range(1, size + 1, 1, slice).toDF(colName)
+      highWaterMarks.foreach { highWaterMark =>
+        verifyIdentityValues(
+          df.select(new Column(GenerateIdentityValues(start, step, highWaterMark)).alias(colName)),
+          IdentityInfo(start, step, highWaterMark),
+          size
+        )
+      }
+    }
+  }
+
+  test("shared state") {
+    val size = 10000
+    val slice = 7
+    val start = -1
+    val step = 3
+    val highWaterMarks = Seq(None, Some((start + 100 * step).toLong))
+    val df = spark.range(1, size + 1, 1, slice).toDF(colName)
+    highWaterMarks.foreach { highWaterMark =>
+      // Create two GenerateIdentityValues expressions that share the same state. They should
+      // generate distinct values.
+      val gev = GenerateIdentityValues(start, step, highWaterMark)
+      val gev2 = gev.copy()
+      verifyIdentityValues(
+        df.select(new Column(
+          If(GreaterThan(col(colName).expr, right = Literal(10)), gev, gev2)).alias(colName)),
+        IdentityInfo(start, step, highWaterMark),
+        size
+      )
+    }
+  }
+
+  test("bigint value range") {
+    val size = 1000
+    val slice = 32
+    val start = Integer.MAX_VALUE.toLong + 1
+    val step = 10
+    val highWaterMark = start - step
+    val df = spark.range(1, size + 1, 1, slice).toDF(colName)
+    verifyIdentityValues(
+      df.select(
+        new Column(GenerateIdentityValues(start, step, Some(highWaterMark))).alias(colName)),
+      IdentityInfo(start, step, Some(highWaterMark)),
+      size
+    )
+  }
+
+  test("overflow initial value") {
+    val events = Log4jUsageLogger.track {
+      val df = spark.range(1, 10, 1, 5).toDF(colName)
+        .select(new Column(GenerateIdentityValues(
+          start = 2,
+          step = Long.MaxValue,
+          highWaterMarkOpt = Some(2 - Long.MaxValue))))
+      val ex = intercept[SparkException] {
+        df.collect()
+      }
+      assert(ex.getMessage.contains("java.lang.ArithmeticException: long overflow"))
+    }
+    val filteredEvents = events.filter { e =>
+      e.tags.get("opType").exists(_ == "delta.identityColumn.overflow")
+    }
+    assert(filteredEvents.size > 0)
+  }
+
+  test("overflow next") {
+    val events = Log4jUsageLogger.track {
+      val df = spark.range(1, 10, 1, 5).toDF(colName)
+        .select(new Column(GenerateIdentityValues(
+          start = Long.MaxValue - 1,
+          step = 2,
+          highWaterMarkOpt = Some(Long.MaxValue - 3))))
+      val ex = intercept[SparkException] {
+        df.collect()
+      }
+      assert(ex.getMessage.contains("java.lang.ArithmeticException: long overflow"))
+    }
+    val filteredEvents = events.filter { e =>
+      e.tags.get("opType").exists(_ == "delta.identityColumn.overflow")
+    }
+    assert(filteredEvents.size > 0)
+  }
+
+  test("invalid high water mark") {
+    val df = spark.range(1, 10, 1, 5).toDF(colName)
+    intercept[IllegalArgumentException] {
+      df.select(new Column(GenerateIdentityValues(
+        start = 1,
+        step = 2,
+        highWaterMarkOpt = Some(4)))
+      ).collect()
+    }
+  }
+
+  test("invalid step") {
+    val df = spark.range(1, 10, 1, 5).toDF(colName)
+    intercept[IllegalArgumentException] {
+      df.select(new Column(GenerateIdentityValues(
+        start = 1,
+        step = 0,
+        highWaterMarkOpt = Some(4)))
+      ).collect()
+    }
+  }
+}
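
---

A minimal sketch of how the expression can be attached to a DataFrame, mirroring the pattern used in the test suite (the input DataFrame, column names, and identity configuration below are illustrative assumptions, not part of this patch):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.delta.GenerateIdentityValues

// Assume an existing DataFrame; here, 100 rows spread across 4 partitions.
val df = spark.range(0, 100, 1, 4).toDF("data")

// Attach a generated identity column starting at 1 with step 1. highWaterMarkOpt = None
// means no identity values have been generated before, so generation starts at `start`.
val withId = df.withColumn(
  "id",
  new Column(GenerateIdentityValues(start = 1L, step = 1L, highWaterMarkOpt = None)))
```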