Skip to content

Commit 6894ab8

Browse files
committed
Address review comments
1 parent 47ef17e commit 6894ab8

File tree

4 files changed

+48
-24
lines changed

4 files changed

+48
-24
lines changed

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5607,7 +5607,7 @@
56075607
},
56085608
"THETA_INVALID_FAMILY" : {
56095609
"message" : [
5610-
"Invalid call to <function>; the `family` parameter must be one of: <validFamilies>. Got: <value>."
5610+
"Invalid call to <function>; the `family` parameter must be one of: <validFamilies>, but got: <value>."
56115611
],
56125612
"sqlState" : "22546"
56135613
},

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation,
2424
import org.apache.spark.SparkUnsupportedOperationException
2525
import org.apache.spark.sql.catalyst.InternalRow
2626
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
27-
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
27+
import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
2828
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, ThetaSketchUtils}
2929
import org.apache.spark.sql.errors.QueryExecutionErrors
3030
import org.apache.spark.sql.internal.types.StringTypeWithCollation
@@ -59,11 +59,11 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
5959
*
6060
* See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
6161
*
62-
* @param child
62+
* @param first
6363
* child expression against which unique counting will occur
64-
* @param lgNomEntriesExpr
64+
* @param second
6565
* the log-base-2 of nomEntries decides the number of buckets for the sketch
66-
* @param familyExpr
66+
* @param third
6767
* the family of the sketch (QUICKSELECT or ALPHA)
6868
* @param mutableAggBufferOffset
6969
* offset for mutable aggregation buffer
@@ -91,24 +91,25 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
9191
since = "4.1.0")
9292
// scalastyle:on line.size.limit
9393
case class ThetaSketchAgg(
94-
child: Expression,
95-
lgNomEntriesExpr: Expression,
96-
familyExpr: Expression,
97-
override val mutableAggBufferOffset: Int,
98-
override val inputAggBufferOffset: Int)
94+
first: Expression,
95+
second: Expression,
96+
third: Expression,
97+
override val mutableAggBufferOffset: Int,
98+
override val inputAggBufferOffset: Int)
9999
extends TypedImperativeAggregate[ThetaSketchState]
100+
with TernaryLike[Expression]
100101
with ExpectsInputTypes {
101102

102103
// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.
103104

104105
private lazy val lgNomEntries: Int = {
105-
val lgNomEntriesInput = lgNomEntriesExpr.eval().asInstanceOf[Int]
106+
val lgNomEntriesInput = second.eval().asInstanceOf[Int]
106107
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName)
107108
lgNomEntriesInput
108109
}
109110

110111
private lazy val family: Family =
111-
ThetaSketchUtils.parseFamily(familyExpr.eval().asInstanceOf[UTF8String].toString, prettyName)
112+
ThetaSketchUtils.parseFamily(third.eval().asInstanceOf[UTF8String].toString, prettyName)
112113

113114
// Constructors
114115
def this(child: Expression) = {
@@ -141,12 +142,6 @@ case class ThetaSketchAgg(
141142
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchAgg =
142143
copy(inputAggBufferOffset = newInputAggBufferOffset)
143144

144-
override protected def withNewChildrenInternal(
145-
newChildren: IndexedSeq[Expression]): ThetaSketchAgg =
146-
copy(child = newChildren(0), lgNomEntriesExpr = newChildren(1), familyExpr = newChildren(2))
147-
148-
override def children: Seq[Expression] = Seq(child, lgNomEntriesExpr, familyExpr)
149-
150145
// Overrides for TypedImperativeAggregate
151146

152147
override def prettyName: String = "theta_sketch_agg"
@@ -200,7 +195,7 @@ case class ThetaSketchAgg(
200195
*/
201196
override def update(updateBuffer: ThetaSketchState, input: InternalRow): ThetaSketchState = {
202197
// Return early for null values.
203-
val v = child.eval(input)
198+
val v = first.eval(input)
204199
if (v == null) return updateBuffer
205200

206201
// Initialized buffer should be UpdatableSketchBuffer, else error out.
@@ -210,7 +205,7 @@ case class ThetaSketchAgg(
210205
}
211206

212207
// Handle the different data types for sketch updates.
213-
child.dataType match {
208+
first.dataType match {
214209
case ArrayType(IntegerType, _) =>
215210
val arr = v.asInstanceOf[ArrayData].toIntArray()
216211
sketch.update(arr)
@@ -237,7 +232,7 @@ case class ThetaSketchAgg(
237232
case _ =>
238233
throw new SparkUnsupportedOperationException(
239234
errorClass = "_LEGACY_ERROR_TEMP_3121",
240-
messageParameters = Map("dataType" -> child.dataType.toString))
235+
messageParameters = Map("dataType" -> first.dataType.toString))
241236
}
242237

243238
UpdatableSketchBuffer(sketch)
@@ -313,6 +308,10 @@ case class ThetaSketchAgg(
313308
this.createAggregationBuffer()
314309
}
315310
}
311+
312+
override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression,
313+
newThird: Expression): Expression = copy(newFirst, newSecond,
314+
newThird, mutableAggBufferOffset, inputAggBufferOffset)
316315
}
317316

318317
/**
@@ -355,7 +354,7 @@ case class ThetaUnionAgg(
355354

356355
// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.
357356

358-
lazy val lgNomEntries: Int = {
357+
private lazy val lgNomEntries: Int = {
359358
val lgNomEntriesInput = right.eval().asInstanceOf[Int]
360359
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName)
361360
lgNomEntriesInput

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ThetaSketchUtils.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,9 @@
1818
package org.apache.spark.sql.catalyst.util
1919

2020
import java.util.Locale
21-
2221
import org.apache.datasketches.common.{Family, SketchesArgumentException}
2322
import org.apache.datasketches.memory.{Memory, MemoryBoundsException}
2423
import org.apache.datasketches.theta.CompactSketch
25-
2624
import org.apache.spark.sql.errors.QueryExecutionErrors
2725

2826

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ThetaSketchUtilsSuite.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,33 @@ class ThetaSketchUtilsSuite extends SparkFunSuite with SQLHelper {
3333
}
3434
}
3535

36+
test("parseFamily: accepts valid family names") {
37+
// Test valid family names (case insensitive)
38+
val validFamilies = Seq(
39+
("QUICKSELECT", ThetaSketchUtils.FAMILY_QUICKSELECT),
40+
("quickselect", ThetaSketchUtils.FAMILY_QUICKSELECT),
41+
("QuickSelect", ThetaSketchUtils.FAMILY_QUICKSELECT),
42+
("ALPHA", ThetaSketchUtils.FAMILY_ALPHA),
43+
("alpha", ThetaSketchUtils.FAMILY_ALPHA),
44+
("Alpha", ThetaSketchUtils.FAMILY_ALPHA)
45+
)
46+
47+
validFamilies.foreach { case (input, expectedFamily) =>
48+
val result = ThetaSketchUtils.parseFamily(input, "test_function")
49+
assert(result.toString == expectedFamily)
50+
}
51+
52+
53+
val invalidFamilyName = "invalid"
54+
checkError(
55+
exception = intercept[SparkRuntimeException] {
56+
ThetaSketchUtils.parseFamily(invalidFamilyName, "test_function")
57+
},
58+
condition = "THETA_INVALID_FAMILY",
59+
parameters = Map("function" -> "`test_function`",
60+
"value" -> "'invalid'",
61+
"validFamilies" -> "`QUICKSELECT`, `ALPHA`") )
62+
}
3663

3764
test("checkLgNomLongs: throws exception for values below minimum") {
3865
val invalidValues = Seq(ThetaSketchUtils.MIN_LG_NOM_LONGS - 1, 0, -5)

0 commit comments

Comments
 (0)