[Spark] Add GenerateIdentityValues UDF (#2915)
#### Which Delta project/connector is this regarding?

- [x] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

This PR is part of #1959

In this PR, we introduce the `GenerateIdentityValues` UDF, which will be used to populate
IDENTITY column values. The UDF is not yet wired into Delta by this PR.

`GenerateIdentityValues` is a simple non-deterministic UDF that keeps a counter seeded from the
user-specified `start` and `step`. Each task advances its counter in increments of
`numPartitions * step`, so value generation can be parallelized across tasks without producing
duplicates.
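
For illustration only, here is a rough sketch of how the expression can be attached to a
DataFrame, mirroring the new test suite (the `spark` session, column names, and parameter values
below are example assumptions; Delta does not call this anywhere yet):

```scala
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.delta.GenerateIdentityValues

// Start at 1, increment by 2, no previously generated values (no high watermark).
val df = spark.range(0, 100, 1, 4).toDF("data")
val withIds = df.select(
  new Column(GenerateIdentityValues(start = 1L, step = 2L, highWaterMarkOpt = None)).alias("id"),
  col("data"))
// Within each task the counter advances by numPartitions * step, and each partition starts at a
// different offset, so the generated ids never collide across tasks.
```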

## How was this patch tested?
New test suite and unit tests for the UDF.

## Does this PR introduce _any_ user-facing changes?
No.
c27kwan authored Apr 25, 2024
1 parent 42f09bd commit 280c878
Showing 2 changed files with 332 additions and 0 deletions.
@@ -0,0 +1,156 @@
/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta

import com.databricks.spark.util.MetricDefinitions
import com.databricks.spark.util.TagDefinitions.TAG_OP_TYPE
import org.apache.spark.sql.delta.metering.DeltaLogging

import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, LongType}

/**
 * Returns the next generated IDENTITY column value based on the underlying
 * [[PartitionIdentityValueGenerator]].
 */
case class GenerateIdentityValues(generator: PartitionIdentityValueGenerator)
  extends LeafExpression with Nondeterministic {

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    generator.initialize(partitionIndex)
  }

  override protected def evalInternal(input: InternalRow): Long = generator.next()

  override def nullable: Boolean = false

  /**
   * Returns Java source code that can be compiled to evaluate this expression.
   * The default behavior is to call the eval method of the expression. Concrete expression
   * implementations should override this to do actual code generation.
   *
   * @param ctx a [[CodegenContext]]
   * @param ev an [[ExprCode]] with unique terms.
   * @return an [[ExprCode]] containing the Java source code to generate the given expression
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val generatorTerm = ctx.addReferenceObj("generator", generator,
      classOf[PartitionIdentityValueGenerator].getName)

    ctx.addPartitionInitializationStatement(s"$generatorTerm.initialize(partitionIndex);")
    ev.copy(code = code"""
      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $generatorTerm.next();
      """, isNull = FalseLiteral)
  }

  /**
   * Returns the [[DataType]] of the result of evaluating this expression. It is
   * invalid to query the dataType of an unresolved expression (i.e., when `resolved` == false).
   */
  override def dataType: DataType = LongType
}

object GenerateIdentityValues {
  def apply(start: Long, step: Long, highWaterMarkOpt: Option[Long]): GenerateIdentityValues = {
    new GenerateIdentityValues(PartitionIdentityValueGenerator(start, step, highWaterMarkOpt))
  }
}

/**
 * Generator of IDENTITY values for one partition.
 *
 * @param start The configured start value for the identity column.
 * @param step IDENTITY value increment.
 * @param highWaterMarkOpt The optional high watermark for identity value generation. If this is
 *                         None, no identity values have been generated in the past and
 *                         generation starts from `start`.
 */
case class PartitionIdentityValueGenerator(
    start: Long,
    step: Long,
    highWaterMarkOpt: Option[Long]) {

  require(step != 0)
  // The value generation logic requires the high water mark to follow the start and step
  // configuration.
  highWaterMarkOpt.foreach(highWaterMark => require((highWaterMark - start) % step == 0))

  private lazy val base = highWaterMarkOpt.map(Math.addExact(_, step)).getOrElse(start)
  private var partitionIndex: Int = -1
  private var nextValue: Long = -1L
  private var increment: Long = -1L

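  // For partition index p, the k-th value generated is
  //   base + p * step + k * (numPartitions * step),
  // so the sequences produced by different partitions interleave without colliding.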
  def initialize(partitionIndex: Int): Unit = {
    if (this.partitionIndex < 0) {
      this.partitionIndex = partitionIndex
      this.nextValue = try {
        Math.addExact(base, Math.multiplyExact(partitionIndex, step))
      } catch {
        case e: ArithmeticException =>
          IdentityOverflowLogger.logOverflow()
          throw e
      }
      // Each value is incremented by numPartitions * step from the previous value.
      this.increment = try {
        // Total number of partitions. In the local execution case where TaskContext is not set,
        // the task is executed as a single partition.
        val numPartitions = Option(TaskContext.get()).map(_.numPartitions()).getOrElse(1)
        Math.multiplyExact(numPartitions, step)
      } catch {
        case e: ArithmeticException =>
          IdentityOverflowLogger.logOverflow()
          throw e
      }
    } else if (this.partitionIndex != partitionIndex) {
      throw SparkException.internalError("Same PartitionIdentityValueGenerator object " +
        s"initialized with two different partitionIndex [oldValue: ${this.partitionIndex}, " +
        s"newValue: $partitionIndex]")
    }
  }

  private def assertInitialized(): Unit = if (partitionIndex == -1) {
    throw SparkException.internalError("PartitionIdentityValueGenerator is not initialized.")
  }

  // Generate the next IDENTITY value.
  def next(): Long = {
    try {
      assertInitialized()
      val ret = nextValue
      nextValue = Math.addExact(nextValue, increment)
      ret
    } catch {
      case e: ArithmeticException =>
        IdentityOverflowLogger.logOverflow()
        throw e
    }
  }
}

object IdentityOverflowLogger extends DeltaLogging {
  def logOverflow(): Unit = {
    recordEvent(
      MetricDefinitions.EVENT_TAHOE,
      Map(TAG_OP_TYPE -> "delta.identityColumn.overflow")
    )
  }
}
@@ -0,0 +1,176 @@
/*
 * Copyright (2021) The Delta Lake Project Authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.delta

import com.databricks.spark.util.Log4jUsageLogger
import org.apache.spark.sql.delta.IdentityColumn.IdentityInfo

import org.apache.spark.SparkException
import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.{GreaterThan, If, Literal}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSparkSession

class GenerateIdentityValuesSuite extends QueryTest with SharedSparkSession {
  private val colName = "id"

  /**
   * Verify the generated IDENTITY values are correct.
   *
   * @param df A DataFrame with a single column containing all the generated IDENTITY values.
   * @param identityInfo IDENTITY information used for verification.
   * @param rowCount Expected row count.
   */
  private def verifyIdentityValues(
      df: => DataFrame,
      identityInfo: IdentityInfo,
      rowCount: Long): Unit = {

    // Check the row count is as expected.
    checkAnswer(df.select(count(col(colName))), Row(rowCount))

    // Check there are no duplicates.
    checkAnswer(df.select(count_distinct(new Column(colName))), Row(rowCount))

    // Check every value follows the start and step configuration.
    val condViolateConfig = s"($colName - ${identityInfo.start}) % ${identityInfo.step} != 0"
    checkAnswer(df.where(condViolateConfig), Seq.empty)

    // Check every value is after the high watermark OR >= start.
    val highWaterMark = identityInfo.highWaterMark.getOrElse(identityInfo.start - identityInfo.step)
    val condViolateHighWaterMark = s"(($colName - $highWaterMark)/${identityInfo.step}) < 0"
    checkAnswer(df.where(condViolateHighWaterMark), Seq.empty)

    // When the high watermark is empty, the first value should be start.
    if (identityInfo.highWaterMark.isEmpty) {
      val agg = if (identityInfo.step > 0) min(new Column(colName)) else max(new Column(colName))
      checkAnswer(df.select(agg), Row(identityInfo.start))
    }
  }

test("basic") {
val sizes = Seq(100, 1000, 10000)
val slices = Seq(2, 7, 15)
val starts = Seq(-3, 0, 1, 5, 43)
val steps = Seq(-3, -2, -1, 1, 2, 3)
for (size <- sizes; slice <- slices; start <- starts; step <- steps) {
val highWaterMarks = Seq(None, Some((start + 100 * step).toLong))
val df = spark.range(1, size + 1, 1, slice).toDF(colName)
highWaterMarks.foreach { highWaterMark =>
verifyIdentityValues(
df.select(new Column(GenerateIdentityValues(start, step, highWaterMark)).alias(colName)),
IdentityInfo(start, step, highWaterMark),
size
)
}
}
}

test("shared state") {
val size = 10000
val slice = 7
val start = -1
val step = 3
val highWaterMarks = Seq(None, Some((start + 100 * step).toLong))
val df = spark.range(1, size + 1, 1, slice).toDF(colName)
highWaterMarks.foreach { highWaterMark =>
// Create two GenerateIdentityValues expressions that share the same state. They should
// generate distinct values.
val gev = GenerateIdentityValues(start, step, highWaterMark)
val gev2 = gev.copy()
verifyIdentityValues(
df.select(new Column(
If(GreaterThan(col(colName).expr, right = Literal(10)), gev, gev2)).alias(colName)),
IdentityInfo(start, step, highWaterMark),
size
)
}
}

test("bigint value range") {
val size = 1000
val slice = 32
val start = Integer.MAX_VALUE.toLong + 1
val step = 10
val highWaterMark = start - step
val df = spark.range(1, size + 1, 1, slice).toDF(colName)
verifyIdentityValues(
df.select(
new Column(GenerateIdentityValues(start, step, Some(highWaterMark))).alias(colName)),
IdentityInfo(start, step, Some(highWaterMark)),
size
)
}

test("overflow initial value") {
val events = Log4jUsageLogger.track {
val df = spark.range(1, 10, 1, 5).toDF(colName)
.select(new Column(GenerateIdentityValues(
start = 2,
step = Long.MaxValue,
highWaterMarkOpt = Some(2 - Long.MaxValue))))
val ex = intercept[SparkException] {
df.collect()
}
assert(ex.getMessage.contains("java.lang.ArithmeticException: long overflow"))
}
val filteredEvents = events.filter { e =>
e.tags.get("opType").exists(_ == "delta.identityColumn.overflow")
}
assert(filteredEvents.size > 0)
}

test("overflow next") {
val events = Log4jUsageLogger.track {
val df = spark.range(1, 10, 1, 5).toDF(colName)
.select(new Column(GenerateIdentityValues(
start = Long.MaxValue - 1,
step = 2,
highWaterMarkOpt = Some(Long.MaxValue - 3))))
val ex = intercept[SparkException] {
df.collect()
}
assert(ex.getMessage.contains("java.lang.ArithmeticException: long overflow"))
}
val filteredEvents = events.filter { e =>
e.tags.get("opType").exists(_ == "delta.identityColumn.overflow")
}
assert(filteredEvents.size > 0)
}

test("invalid high water mark") {
val df = spark.range(1, 10, 1, 5).toDF(colName)
intercept[IllegalArgumentException] {
df.select(new Column(GenerateIdentityValues(
start = 1,
step = 2,
highWaterMarkOpt = Some(4)))
).collect()
}
}

test("invalid step") {
val df = spark.range(1, 10, 1, 5).toDF(colName)
intercept[IllegalArgumentException] {
df.select(new Column(GenerateIdentityValues(
start = 1,
step = 0,
highWaterMarkOpt = Some(4)))
).collect()
}
}
}
