# [Spark] Add GenerateIdentityValues UDF (#2915)
#### Which Delta project/connector is this regarding?

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

This PR is part of #1959. It introduces the `GenerateIdentityValues` UDF used for populating IDENTITY column values; the UDF is not yet used anywhere in Delta in this PR.

`GenerateIdentityValues` is a simple non-deterministic UDF that keeps a counter seeded from the user-specified `start` and `step`. Each task starts at `base + partitionIndex * step` (where `base` is `start`, or the high watermark plus `step` when a watermark exists) and counts in increments of `numPartitions * step`, so value generation can be parallelized across tasks without collisions. A sketch of the interleaving scheme follows at the end of this description.

## How was this patch tested?

New test suite and unit tests for the UDF.

## Does this PR introduce _any_ user-facing changes?

No.
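To make the parallel counting scheme concrete, here is a minimal, self-contained sketch (illustrative only, not part of this commit; `IdentitySchemeSketch` and `identitySequence` are hypothetical names):

```scala
// Sketch of the interleaving scheme implemented by the UDF: partition p of n
// emits base + p*step, base + (p+n)*step, base + (p+2n)*step, ...
object IdentitySchemeSketch {
  /** The sequence of values partition `p` out of `n` partitions would emit. */
  def identitySequence(base: Long, step: Long, p: Int, n: Int): Iterator[Long] =
    Iterator.iterate(base + p * step)(_ + n * step)

  def main(args: Array[String]): Unit = {
    val (start, step, n) = (100L, 3L, 2)
    // Partition 0 emits 100, 106, 112, ...; partition 1 emits 103, 109, 115, ...
    // The per-partition sequences interleave and never overlap.
    (0 until n).foreach { p =>
      println(s"partition $p: " + identitySequence(start, step, p, n).take(4).mkString(", "))
    }
  }
}
```

Because every partition advances by the same stride `n * step` from a distinct offset, the union of all partitions' outputs is collision-free by construction.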
Showing 2 changed files with 332 additions and 0 deletions.
#### spark/src/main/scala/org/apache/spark/sql/delta/GenerateIdentityValues.scala (156 additions)
```scala
/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta

import com.databricks.spark.util.MetricDefinitions
import com.databricks.spark.util.TagDefinitions.TAG_OP_TYPE
import org.apache.spark.sql.delta.metering.DeltaLogging

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}

/**
 * Returns the next generated IDENTITY column value based on the underlying
 * [[PartitionIdentityValueGenerator]].
 */
case class GenerateIdentityValues(generator: PartitionIdentityValueGenerator)
  extends LeafExpression with Nondeterministic {

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    generator.initialize(partitionIndex)
  }

  override protected def evalInternal(input: InternalRow): Long = generator.next()

  override def nullable: Boolean = false

  /**
   * Returns Java source code that can be compiled to evaluate this expression.
   * The default behavior is to call the eval method of the expression. Concrete expression
   * implementations should override this to do actual code generation.
   *
   * @param ctx a [[CodegenContext]]
   * @param ev an [[ExprCode]] with unique terms.
   * @return an [[ExprCode]] containing the Java source code to generate the given expression
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val generatorTerm = ctx.addReferenceObj("generator", generator,
      classOf[PartitionIdentityValueGenerator].getName)

    ctx.addPartitionInitializationStatement(s"$generatorTerm.initialize(partitionIndex);")
    ev.copy(code = code"""
      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $generatorTerm.next();
      """, isNull = FalseLiteral)
  }

  /**
   * Returns the [[DataType]] of the result of evaluating this expression. It is
   * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false).
   */
  override def dataType: DataType = LongType
}

object GenerateIdentityValues {
  def apply(start: Long, step: Long, highWaterMarkOpt: Option[Long]): GenerateIdentityValues = {
    new GenerateIdentityValues(PartitionIdentityValueGenerator(start, step, highWaterMarkOpt))
  }
}
/**
 * Generator of IDENTITY values for one partition.
 *
 * @param start The configured start value for the identity column.
 * @param highWaterMarkOpt The optional high watermark for identity value generation. If this is
 *                         None, no identity values have been generated in the past and
 *                         generation should start from `start`.
 * @param step IDENTITY value increment.
 */
case class PartitionIdentityValueGenerator(
    start: Long,
    step: Long,
    highWaterMarkOpt: Option[Long]) {

  require(step != 0)
  // The value generation logic requires the high water mark to follow the start and step
  // configuration.
  highWaterMarkOpt.foreach(highWaterMark => require((highWaterMark - start) % step == 0))

  private lazy val base = highWaterMarkOpt.map(Math.addExact(_, step)).getOrElse(start)
  private var partitionIndex: Int = -1
  private var nextValue: Long = -1L
  private var increment: Long = -1L

  def initialize(partitionIndex: Int): Unit = {
    if (this.partitionIndex < 0) {
      this.partitionIndex = partitionIndex
      this.nextValue = try {
        Math.addExact(base, Math.multiplyExact(partitionIndex, step))
      } catch {
        case e: ArithmeticException =>
          IdentityOverflowLogger.logOverflow()
          throw e
      }
      // Each value is incremented by numPartitions * step from the previous value.
      this.increment = try {
        // Total number of partitions. In the local execution case where TaskContext is not set,
        // the task is executed as a single partition.
        val numPartitions = Option(TaskContext.get()).map(_.numPartitions()).getOrElse(1)
        Math.multiplyExact(numPartitions, step)
      } catch {
        case e: ArithmeticException =>
          IdentityOverflowLogger.logOverflow()
          throw e
      }
    } else if (this.partitionIndex != partitionIndex) {
      throw SparkException.internalError("Same PartitionIdentityValueGenerator object " +
        s"initialized with two different partitionIndex [oldValue: ${this.partitionIndex}, " +
        s"newValue: $partitionIndex]")
    }
  }

  private def assertInitialized(): Unit = if (partitionIndex == -1) {
    throw SparkException.internalError("PartitionIdentityValueGenerator is not initialized.")
  }

  // Generate the next IDENTITY value.
  def next(): Long = {
    try {
      assertInitialized()
      val ret = nextValue
      nextValue = Math.addExact(nextValue, increment)
      ret
    } catch {
      case e: ArithmeticException =>
        IdentityOverflowLogger.logOverflow()
        throw e
    }
  }
}

object IdentityOverflowLogger extends DeltaLogging {
  def logOverflow(): Unit = {
    recordEvent(
      MetricDefinitions.EVENT_TAHOE,
      Map(TAG_OP_TYPE -> "delta.identityColumn.overflow")
    )
  }
}
```
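For orientation (not part of the commit), here is how the new expression could be exercised from a DataFrame, mirroring the pattern used by the test suite below; the active `SparkSession` named `spark` and the column names are assumptions:

```scala
// Hypothetical usage sketch, following the test suite's pattern.
import org.apache.spark.sql.Column
import org.apache.spark.sql.delta.GenerateIdentityValues
import org.apache.spark.sql.functions.col

// Project the non-deterministic expression over any DataFrame; each row,
// regardless of which task produces it, receives a distinct value of the
// form base + k * step.
val df = spark.range(0, 100, 1, 4).toDF("value")
val withIds = df.select(
  col("value"),
  new Column(GenerateIdentityValues(start = 1L, step = 2L, highWaterMarkOpt = None)).alias("id"))
withIds.show()
```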
#### spark/src/test/scala/org/apache/spark/sql/delta/GenerateIdentityValuesSuite.scala (176 additions)
```scala
/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta

import com.databricks.spark.util.Log4jUsageLogger
import org.apache.spark.sql.delta.IdentityColumn.IdentityInfo

import org.apache.spark.SparkException
import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.{GreaterThan, If, Literal}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession

class GenerateIdentityValuesSuite extends QueryTest with SharedSparkSession {
  private val colName = "id"

  /**
   * Verify the generated IDENTITY values are correct.
   *
   * @param df A DataFrame with a single column containing all the generated IDENTITY values.
   * @param identityInfo IDENTITY information used for verification.
   * @param rowCount Expected row count.
   */
  private def verifyIdentityValues(
      df: => DataFrame,
      identityInfo: IdentityInfo,
      rowCount: Long): Unit = {

    // Check the row count is as expected.
    checkAnswer(df.select(count(col(colName))), Row(rowCount))

    // Check there are no duplicates.
    checkAnswer(df.select(count_distinct(new Column(colName))), Row(rowCount))

    // Check every value follows the start and step configuration.
    val condViolateConfig = s"($colName - ${identityInfo.start}) % ${identityInfo.step} != 0"
    checkAnswer(df.where(condViolateConfig), Seq.empty)

    // Check every value is after the high watermark OR >= start.
    val highWaterMark = identityInfo.highWaterMark.getOrElse(identityInfo.start - identityInfo.step)
    val condViolateHighWaterMark = s"(($colName - $highWaterMark)/${identityInfo.step}) < 0"
    checkAnswer(df.where(condViolateHighWaterMark), Seq.empty)

    // When the high watermark is empty, the first value should be start.
    if (identityInfo.highWaterMark.isEmpty) {
      val agg = if (identityInfo.step > 0) min(new Column(colName)) else max(new Column(colName))
      checkAnswer(df.select(agg), Row(identityInfo.start))
    }
  }

  test("basic") {
    val sizes = Seq(100, 1000, 10000)
    val slices = Seq(2, 7, 15)
    val starts = Seq(-3, 0, 1, 5, 43)
    val steps = Seq(-3, -2, -1, 1, 2, 3)
    for (size <- sizes; slice <- slices; start <- starts; step <- steps) {
      val highWaterMarks = Seq(None, Some((start + 100 * step).toLong))
      val df = spark.range(1, size + 1, 1, slice).toDF(colName)
      highWaterMarks.foreach { highWaterMark =>
        verifyIdentityValues(
          df.select(new Column(GenerateIdentityValues(start, step, highWaterMark)).alias(colName)),
          IdentityInfo(start, step, highWaterMark),
          size
        )
      }
    }
  }

  test("shared state") {
    val size = 10000
    val slice = 7
    val start = -1
    val step = 3
    val highWaterMarks = Seq(None, Some((start + 100 * step).toLong))
    val df = spark.range(1, size + 1, 1, slice).toDF(colName)
    highWaterMarks.foreach { highWaterMark =>
      // Create two GenerateIdentityValues expressions that share the same state. They should
      // generate distinct values.
      val gev = GenerateIdentityValues(start, step, highWaterMark)
      val gev2 = gev.copy()
      verifyIdentityValues(
        df.select(new Column(
          If(GreaterThan(col(colName).expr, right = Literal(10)), gev, gev2)).alias(colName)),
        IdentityInfo(start, step, highWaterMark),
        size
      )
    }
  }

  test("bigint value range") {
    val size = 1000
    val slice = 32
    val start = Integer.MAX_VALUE.toLong + 1
    val step = 10
    val highWaterMark = start - step
    val df = spark.range(1, size + 1, 1, slice).toDF(colName)
    verifyIdentityValues(
      df.select(
        new Column(GenerateIdentityValues(start, step, Some(highWaterMark))).alias(colName)),
      IdentityInfo(start, step, Some(highWaterMark)),
      size
    )
  }

  test("overflow initial value") {
    val events = Log4jUsageLogger.track {
      val df = spark.range(1, 10, 1, 5).toDF(colName)
        .select(new Column(GenerateIdentityValues(
          start = 2,
          step = Long.MaxValue,
          highWaterMarkOpt = Some(2 - Long.MaxValue))))
      val ex = intercept[SparkException] {
        df.collect()
      }
      assert(ex.getMessage.contains("java.lang.ArithmeticException: long overflow"))
    }
    val filteredEvents = events.filter { e =>
      e.tags.get("opType").exists(_ == "delta.identityColumn.overflow")
    }
    assert(filteredEvents.size > 0)
  }

  test("overflow next") {
    val events = Log4jUsageLogger.track {
      val df = spark.range(1, 10, 1, 5).toDF(colName)
        .select(new Column(GenerateIdentityValues(
          start = Long.MaxValue - 1,
          step = 2,
          highWaterMarkOpt = Some(Long.MaxValue - 3))))
      val ex = intercept[SparkException] {
        df.collect()
      }
      assert(ex.getMessage.contains("java.lang.ArithmeticException: long overflow"))
    }
    val filteredEvents = events.filter { e =>
      e.tags.get("opType").exists(_ == "delta.identityColumn.overflow")
    }
    assert(filteredEvents.size > 0)
  }

  test("invalid high water mark") {
    val df = spark.range(1, 10, 1, 5).toDF(colName)
    intercept[IllegalArgumentException] {
      df.select(new Column(GenerateIdentityValues(
        start = 1,
        step = 2,
        highWaterMarkOpt = Some(4)))
      ).collect()
    }
  }

  test("invalid step") {
    val df = spark.range(1, 10, 1, 5).toDF(colName)
    intercept[IllegalArgumentException] {
      df.select(new Column(GenerateIdentityValues(
        start = 1,
        step = 0,
        highWaterMarkOpt = Some(4)))
      ).collect()
    }
  }
}
```