-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-53848] Add ability to support Alpha family in Theta Aggregates #52551
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1198,6 +1198,17 @@ object functions { | |
def theta_sketch_agg(e: Column, lgNomEntries: Column): Column = | ||
Column.fn("theta_sketch_agg", e, lgNomEntries) | ||
|
||
/** | ||
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch | ||
* built with the values in the input column and configured with the `lgNomEntries` nominal | ||
* entries and `family`. | ||
* | ||
* @group agg_funcs | ||
* @since 4.1.0 | ||
*/ | ||
def theta_sketch_agg(e: Column, lgNomEntries: Column, family: Column): Column = | ||
Column.fn("theta_sketch_agg", e, lgNomEntries, family) | ||
|
||
/** | ||
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch | ||
* built with the values in the input column and configured with the `lgNomEntries` nominal | ||
|
@@ -1242,6 +1253,47 @@ object functions { | |
def theta_sketch_agg(columnName: String): Column = | ||
theta_sketch_agg(Column(columnName)) | ||
|
||
/** | ||
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch | ||
* built with the values in the input column, configured with `lgNomEntries` and `family`. | ||
* | ||
* @group agg_funcs | ||
* @since 4.1.0 | ||
*/ | ||
def theta_sketch_agg(e: Column, lgNomEntries: Int, family: String): Column = | ||
Column.fn("theta_sketch_agg", e, lit(lgNomEntries), lit(family)) | ||
|
||
/** | ||
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch | ||
* built with the values in the input column, configured with `lgNomEntries` and `family`. | ||
* | ||
* @group agg_funcs | ||
* @since 4.1.0 | ||
*/ | ||
def theta_sketch_agg(columnName: String, lgNomEntries: Int, family: String): Column = | ||
theta_sketch_agg(Column(columnName), lgNomEntries, family) | ||
|
||
/** | ||
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch | ||
* built with the values in the input column, configured with the specified `family` and default | ||
* lgNomEntries. | ||
* | ||
* @group agg_funcs | ||
* @since 4.1.0 | ||
*/ | ||
def theta_sketch_agg(e: Column, family: String): Column = | ||
theta_sketch_agg(e, 12, family) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The hardcoded 12 matches the Catalyst default, so behaviorally it’s fine. Still, duplicating an internal default in the public layer is a bit awkward .We could consider making this explicit (functions above preferred) or defining a local constant for clarity. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I gave some thoughts on this before adding it here. Ideally the resolution to use 12 for logNomEntries should be in But, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd recommend we remove these two |
||
|
||
/** | ||
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch | ||
* built with the values in the input column, configured with specified `family`. | ||
* | ||
* @group agg_funcs | ||
* @since 4.1.0 | ||
*/ | ||
def theta_sketch_agg(columnName: String, family: String): Column = | ||
theta_sketch_agg(columnName, 12, family) | ||
|
||
/** | ||
* Aggregate function: returns the compact binary representation of the Datasketches | ||
* ThetaSketch, generated by the union of Datasketches ThetaSketch instances in the input column | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,14 +17,14 @@ | |
|
||
package org.apache.spark.sql.catalyst.expressions.aggregate | ||
|
||
import org.apache.datasketches.common.Family | ||
import org.apache.datasketches.memory.Memory | ||
import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation, Sketch, Union, UpdateSketch, UpdateSketchBuilder} | ||
|
||
import org.apache.spark.SparkUnsupportedOperationException | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal} | ||
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate | ||
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike} | ||
import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike} | ||
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, ThetaSketchUtils} | ||
import org.apache.spark.sql.errors.QueryExecutionErrors | ||
import org.apache.spark.sql.internal.types.StringTypeWithCollation | ||
|
@@ -59,10 +59,12 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState { | |
* | ||
* See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information. | ||
* | ||
* @param left | ||
* @param first | ||
* child expression against which unique counting will occur | ||
* @param right | ||
* @param second | ||
* the log-base-2 of nomEntries decides the number of buckets for the sketch | ||
* @param third | ||
* the family of the sketch (QUICKSELECT or ALPHA) | ||
* @param mutableAggBufferOffset | ||
* offset for mutable aggregation buffer | ||
* @param inputAggBufferOffset | ||
|
@@ -71,46 +73,67 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState { | |
// scalastyle:off line.size.limit | ||
@ExpressionDescription( | ||
usage = """ | ||
_FUNC_(expr, lgNomEntries) - Returns the ThetaSketch compact binary representation. | ||
_FUNC_(expr, lgNomEntries, family) - Returns the ThetaSketch compact binary representation. | ||
`lgNomEntries` (optional) is the log-base-2 of nominal entries, with nominal entries deciding | ||
the number buckets or slots for the ThetaSketch. """, | ||
the number buckets or slots for the ThetaSketch. | ||
`family` (optional) is the sketch family, either 'QUICKSELECT' or 'ALPHA' (defaults to | ||
'QUICKSELECT').""", | ||
examples = """ | ||
Examples: | ||
> SELECT theta_sketch_estimate(_FUNC_(col)) FROM VALUES (1), (1), (2), (2), (3) tab(col); | ||
3 | ||
> SELECT theta_sketch_estimate(_FUNC_(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col); | ||
3 | ||
> SELECT theta_sketch_estimate(_FUNC_(col, 15, 'ALPHA')) FROM VALUES (1), (1), (2), (2), (3) tab(col); | ||
3 | ||
""", | ||
group = "agg_funcs", | ||
since = "4.1.0") | ||
// scalastyle:on line.size.limit | ||
case class ThetaSketchAgg( | ||
left: Expression, | ||
right: Expression, | ||
override val mutableAggBufferOffset: Int, | ||
override val inputAggBufferOffset: Int) | ||
first: Expression, | ||
second: Expression, | ||
third: Expression, | ||
override val mutableAggBufferOffset: Int, | ||
override val inputAggBufferOffset: Int) | ||
extends TypedImperativeAggregate[ThetaSketchState] | ||
with BinaryLike[Expression] | ||
karuppayya marked this conversation as resolved.
Show resolved
Hide resolved
|
||
with TernaryLike[Expression] | ||
with ExpectsInputTypes { | ||
|
||
// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation. | ||
|
||
lazy val lgNomEntries: Int = { | ||
val lgNomEntriesInput = right.eval().asInstanceOf[Int] | ||
private lazy val lgNomEntries: Int = { | ||
val lgNomEntriesInput = second.eval().asInstanceOf[Int] | ||
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName) | ||
lgNomEntriesInput | ||
} | ||
|
||
// Constructors | ||
private lazy val family: Family = { | ||
val familyName = third.eval().asInstanceOf[UTF8String] | ||
ThetaSketchUtils.parseFamily(familyName.toString, prettyName) | ||
} | ||
|
||
// Constructors | ||
def this(child: Expression) = { | ||
this(child, Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS), 0, 0) | ||
this(child, | ||
Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS), | ||
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)), | ||
0, 0) | ||
} | ||
|
||
def this(child: Expression, lgNomEntries: Expression) = { | ||
this(child, lgNomEntries, 0, 0) | ||
this(child, | ||
lgNomEntries, | ||
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)), | ||
0, 0) | ||
} | ||
|
||
def this(child: Expression, lgNomEntries: Expression, family: Expression) = { | ||
this(child, lgNomEntries, family, 0, 0) | ||
} | ||
|
||
def this(child: Expression, lgNomEntries: Int) = { | ||
this(child, Literal(lgNomEntries), 0, 0) | ||
this(child, Literal(lgNomEntries)) | ||
} | ||
|
||
// Copy constructors required by ImperativeAggregate | ||
|
@@ -121,16 +144,11 @@ case class ThetaSketchAgg( | |
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchAgg = | ||
copy(inputAggBufferOffset = newInputAggBufferOffset) | ||
|
||
override protected def withNewChildrenInternal( | ||
newLeft: Expression, | ||
newRight: Expression): ThetaSketchAgg = | ||
copy(left = newLeft, right = newRight) | ||
|
||
// Overrides for TypedImperativeAggregate | ||
|
||
override def prettyName: String = "theta_sketch_agg" | ||
|
||
override def inputTypes: Seq[AbstractDataType] = | ||
override def inputTypes: Seq[AbstractDataType] = { | ||
Seq( | ||
TypeCollection( | ||
ArrayType(IntegerType), | ||
|
@@ -141,21 +159,24 @@ case class ThetaSketchAgg( | |
IntegerType, | ||
LongType, | ||
StringTypeWithCollation(supportsTrimCollation = true)), | ||
IntegerType) | ||
IntegerType, | ||
StringType) | ||
} | ||
|
||
override def dataType: DataType = BinaryType | ||
|
||
override def nullable: Boolean = false | ||
|
||
/** | ||
* Instantiate an UpdateSketch instance using the lgNomEntries param. | ||
* Instantiate an UpdateSketch instance using the lgNomEntries and family params. | ||
* | ||
* @return | ||
* an UpdateSketch instance wrapped with UpdatableSketchBuffer | ||
*/ | ||
override def createAggregationBuffer(): ThetaSketchState = { | ||
val builder = new UpdateSketchBuilder | ||
builder.setLogNominalEntries(lgNomEntries) | ||
builder.setFamily(family) | ||
UpdatableSketchBuffer(builder.build) | ||
} | ||
|
||
|
@@ -176,7 +197,7 @@ case class ThetaSketchAgg( | |
*/ | ||
override def update(updateBuffer: ThetaSketchState, input: InternalRow): ThetaSketchState = { | ||
// Return early for null values. | ||
val v = left.eval(input) | ||
val v = first.eval(input) | ||
if (v == null) return updateBuffer | ||
|
||
// Initialized buffer should be UpdatableSketchBuffer, else error out. | ||
|
@@ -186,7 +207,7 @@ case class ThetaSketchAgg( | |
} | ||
|
||
// Handle the different data types for sketch updates. | ||
left.dataType match { | ||
first.dataType match { | ||
case ArrayType(IntegerType, _) => | ||
val arr = v.asInstanceOf[ArrayData].toIntArray() | ||
sketch.update(arr) | ||
|
@@ -213,7 +234,7 @@ case class ThetaSketchAgg( | |
case _ => | ||
throw new SparkUnsupportedOperationException( | ||
errorClass = "_LEGACY_ERROR_TEMP_3121", | ||
messageParameters = Map("dataType" -> left.dataType.toString)) | ||
messageParameters = Map("dataType" -> first.dataType.toString)) | ||
} | ||
|
||
UpdatableSketchBuffer(sketch) | ||
|
@@ -289,6 +310,10 @@ case class ThetaSketchAgg( | |
this.createAggregationBuffer() | ||
} | ||
} | ||
|
||
override protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, | ||
newThird: Expression): Expression = copy(newFirst, newSecond, | ||
newThird, mutableAggBufferOffset, inputAggBufferOffset) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. super nit, the withNewChildrenInternal override doesn’t need to explicitly pass mutableAggBufferOffset or inputAggBufferOffset. This is redundant. Return type should also be can have something like this:
|
||
} | ||
|
||
/** | ||
|
@@ -331,7 +356,7 @@ case class ThetaUnionAgg( | |
|
||
// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation. | ||
|
||
lazy val lgNomEntries: Int = { | ||
private lazy val lgNomEntries: Int = { | ||
val lgNomEntriesInput = right.eval().asInstanceOf[Int] | ||
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName) | ||
lgNomEntriesInput | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the sake of consistency, we can consider changing this function to look like this:
Similarly in the other builtin.py