From 5453963901f2902981b60f2b8dfca726fff26b2c Mon Sep 17 00:00:00 2001 From: Carmen Kwan Date: Thu, 19 Dec 2024 20:36:20 +0100 Subject: [PATCH] fix roundToNext --- .../spark/sql/delta/IdentityColumn.scala | 13 +++++++--- .../sql/delta/IdentityColumnSyncSuite.scala | 24 +++++++++++++------ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/IdentityColumn.scala b/spark/src/main/scala/org/apache/spark/sql/delta/IdentityColumn.scala index 3d829274ab..8e400cacdc 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/IdentityColumn.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/IdentityColumn.scala @@ -157,13 +157,20 @@ object IdentityColumn extends DeltaLogging { value } else { // An identity value follows the formula start + step * n. So n = (value - start) / step. + // Where n is a non-negative integer if the value respects the start. // Since the value doesn't follow this formula, we need to ceil n. // corrected value = start + step * ceil(n). // However, we can't cast to Double for division because it's only accurate up to 54 bits. - // Instead, we will do a floored division and add 1 if it's a positive step or -1 if - // it is a negative step. + // Instead, we will do a floored division and add 1. // start + step * ((value - start) / step + 1) - val stepMultiple = (valueOffset / step) + Math.signum(step).toInt + val quotient = valueOffset / step + // `valueOffset` will have the same sign as `step` if `value` respects the start. + val stepMultiple = if (Math.signum(valueOffset) == Math.signum(step)) { + Math.addExact(quotient, 1L) + } else { + // Don't add one. Otherwise, we end up rounding 2 values up, which may skip the start. + quotient + } Math.addExact( start, Math.multiplyExact(step, stepMultiple) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnSyncSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnSyncSuite.scala index 8c939d9f7f..2c43530374 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnSyncSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/IdentityColumnSyncSuite.scala @@ -583,19 +583,29 @@ trait IdentityColumnSyncSuiteBase val negStart = -7L val posLargeStart = Long.MaxValue - 10000 val negLargeStart = Long.MinValue + 10000 - for (largeStart <- Seq(posStart, negStart, posLargeStart, negLargeStart)) { + for (start <- Seq(posStart, negStart, posLargeStart, negLargeStart)) { + assert(IdentityColumn.roundToNext(start = start, step = 3L, value = start) === start) assert(IdentityColumn.roundToNext( - start = largeStart, step = 3L, value = largeStart) === largeStart) + start = start, step = 3L, value = start + 5L) === start + 6L) assert(IdentityColumn.roundToNext( - start = largeStart, step = 3L, value = largeStart + 5L) === largeStart + 6L) + start = start, step = 3L, value = start + 6L) === start + 6L) assert(IdentityColumn.roundToNext( - start = largeStart, step = 3L, value = largeStart + 6L) === largeStart + 6L) + start = start, step = 3L, value = start - 5L) === start - 3L) // bad watermark assert(IdentityColumn.roundToNext( - start = largeStart, step = -3L, value = largeStart) === largeStart) + start = start, step = 3L, value = start - 7L) === start - 6L) // bad watermark assert(IdentityColumn.roundToNext( - start = largeStart, step = -3L, value = largeStart - 5L) === largeStart - 6L) + start = start, step = 3L, value = start - 6L) === start - 6L) // bad watermark + assert(IdentityColumn.roundToNext(start = start, step = -3L, value = start) === start) assert(IdentityColumn.roundToNext( - start = largeStart, step = -3L, value = largeStart - 6L) === largeStart - 6L) + start = start, step = -3L, value = start - 5L) === start - 6L) + assert(IdentityColumn.roundToNext( + start = start, step = -3L, value = start - 6L) === start - 6L) + assert(IdentityColumn.roundToNext( + start = start, step = -3L, value = start + 5L) === start + 3L) // bad watermark + assert(IdentityColumn.roundToNext( + start = start, step = -3L, value = start + 7L) === start + 6L) // bad watermark + assert(IdentityColumn.roundToNext( + start = start, step = -3L, value = start + 6L) === start + 6L) // bad watermark } } }