@@ -24,7 +24,7 @@ import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation,
2424import org .apache .spark .SparkUnsupportedOperationException
2525import org .apache .spark .sql .catalyst .InternalRow
2626import org .apache .spark .sql .catalyst .expressions .{ExpectsInputTypes , Expression , ExpressionDescription , Literal }
27- import org .apache .spark .sql .catalyst .trees .{BinaryLike , UnaryLike }
27+ import org .apache .spark .sql .catalyst .trees .{BinaryLike , TernaryLike , UnaryLike }
2828import org .apache .spark .sql .catalyst .util .{ArrayData , CollationFactory , ThetaSketchUtils }
2929import org .apache .spark .sql .errors .QueryExecutionErrors
3030import org .apache .spark .sql .internal .types .StringTypeWithCollation
@@ -59,11 +59,11 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
5959 *
6060 * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html ]] for more information.
6161 *
62- * @param child
62+ * @param first
6363 * child expression against which unique counting will occur
64- * @param lgNomEntriesExpr
64+ * @param second
6565 * the log-base-2 of nomEntries decides the number of buckets for the sketch
66- * @param familyExpr
66+ * @param third
6767 * the family of the sketch (QUICKSELECT or ALPHA)
6868 * @param mutableAggBufferOffset
6969 * offset for mutable aggregation buffer
@@ -91,24 +91,25 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
9191 since = " 4.1.0" )
9292// scalastyle:on line.size.limit
9393case class ThetaSketchAgg (
94- child : Expression ,
95- lgNomEntriesExpr : Expression ,
96- familyExpr : Expression ,
97- override val mutableAggBufferOffset : Int ,
98- override val inputAggBufferOffset : Int )
94+ first : Expression ,
95+ second : Expression ,
96+ third : Expression ,
97+ override val mutableAggBufferOffset : Int ,
98+ override val inputAggBufferOffset : Int )
9999 extends TypedImperativeAggregate [ThetaSketchState ]
100+ with TernaryLike [Expression ]
100101 with ExpectsInputTypes {
101102
102103 // ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.
103104
104105 private lazy val lgNomEntries : Int = {
105- val lgNomEntriesInput = lgNomEntriesExpr .eval().asInstanceOf [Int ]
106+ val lgNomEntriesInput = second .eval().asInstanceOf [Int ]
106107 ThetaSketchUtils .checkLgNomLongs(lgNomEntriesInput, prettyName)
107108 lgNomEntriesInput
108109 }
109110
110111 private lazy val family : Family =
111- ThetaSketchUtils .parseFamily(familyExpr .eval().asInstanceOf [UTF8String ].toString, prettyName)
112+ ThetaSketchUtils .parseFamily(third .eval().asInstanceOf [UTF8String ].toString, prettyName)
112113
113114 // Constructors
114115 def this (child : Expression ) = {
@@ -141,12 +142,6 @@ case class ThetaSketchAgg(
141142 override def withNewInputAggBufferOffset (newInputAggBufferOffset : Int ): ThetaSketchAgg =
142143 copy(inputAggBufferOffset = newInputAggBufferOffset)
143144
144- override protected def withNewChildrenInternal (
145- newChildren : IndexedSeq [Expression ]): ThetaSketchAgg =
146- copy(child = newChildren(0 ), lgNomEntriesExpr = newChildren(1 ), familyExpr = newChildren(2 ))
147-
148- override def children : Seq [Expression ] = Seq (child, lgNomEntriesExpr, familyExpr)
149-
150145 // Overrides for TypedImperativeAggregate
151146
152147 override def prettyName : String = " theta_sketch_agg"
@@ -200,7 +195,7 @@ case class ThetaSketchAgg(
200195 */
201196 override def update (updateBuffer : ThetaSketchState , input : InternalRow ): ThetaSketchState = {
202197 // Return early for null values.
203- val v = child .eval(input)
198+ val v = first .eval(input)
204199 if (v == null ) return updateBuffer
205200
206201 // Initialized buffer should be UpdatableSketchBuffer, else error out.
@@ -210,7 +205,7 @@ case class ThetaSketchAgg(
210205 }
211206
212207 // Handle the different data types for sketch updates.
213- child .dataType match {
208+ first .dataType match {
214209 case ArrayType (IntegerType , _) =>
215210 val arr = v.asInstanceOf [ArrayData ].toIntArray()
216211 sketch.update(arr)
@@ -237,7 +232,7 @@ case class ThetaSketchAgg(
237232 case _ =>
238233 throw new SparkUnsupportedOperationException (
239234 errorClass = " _LEGACY_ERROR_TEMP_3121" ,
240- messageParameters = Map (" dataType" -> child .dataType.toString))
235+ messageParameters = Map (" dataType" -> first .dataType.toString))
241236 }
242237
243238 UpdatableSketchBuffer (sketch)
@@ -313,6 +308,10 @@ case class ThetaSketchAgg(
313308 this .createAggregationBuffer()
314309 }
315310 }
311+
312+ override protected def withNewChildrenInternal (newFirst : Expression , newSecond : Expression ,
313+ newThird : Expression ): Expression = copy(newFirst, newSecond,
314+ newThird, mutableAggBufferOffset, inputAggBufferOffset)
316315}
317316
318317/**
@@ -355,7 +354,7 @@ case class ThetaUnionAgg(
355354
356355 // ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.
357356
358- lazy val lgNomEntries : Int = {
357+ private lazy val lgNomEntries : Int = {
359358 val lgNomEntriesInput = right.eval().asInstanceOf [Int ]
360359 ThetaSketchUtils .checkLgNomLongs(lgNomEntriesInput, prettyName)
361360 lgNomEntriesInput
0 commit comments