Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5605,6 +5605,12 @@
],
"sqlState" : "428EK"
},
"THETA_INVALID_FAMILY" : {
"message" : [
"Invalid call to <function>; the `family` parameter must be one of: <validFamilies>, but got: <value>."
],
"sqlState" : "22546"
},
"THETA_INVALID_INPUT_SKETCH_BUFFER" : {
"message" : [
"Invalid call to <function>; only valid Theta sketch buffers are supported as inputs (such as those produced by the `theta_sketch_agg` function)."
Expand Down
9 changes: 7 additions & 2 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4337,12 +4337,17 @@ def hll_union(
def theta_sketch_agg(
col: "ColumnOrName",
lgNomEntries: Optional[Union[int, Column]] = None,
family: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the sake of consistency, we can consider changing this function to look like this:

def theta_sketch_agg(
    col: "ColumnOrName",
    lgNomEntries: Optional[Union[int, Column]] = None,
    family: Optional["ColumnOrName"] = None,
) -> Column:
    fn = "theta_sketch_agg"
    _lgNomEntries = lit(12) if lgNomEntries is None else lit(lgNomEntries)
    _family = lit("QUICKSELECT") if family is None else _to_col(family)

    return _invoke_function_over_columns(fn, col, _lgNomEntries, _family)

Similarly in the other builtin.py

) -> Column:
fn = "theta_sketch_agg"
if lgNomEntries is None:
if lgNomEntries is None and family is None:
return _invoke_function_over_columns(fn, col)
else:
elif family is None:
return _invoke_function_over_columns(fn, col, lit(lgNomEntries))
else:
if lgNomEntries is None:
lgNomEntries = 12 # default value
return _invoke_function_over_columns(fn, col, lit(lgNomEntries), lit(family))


theta_sketch_agg.__doc__ = pysparkfuncs.theta_sketch_agg.__doc__
Expand Down
51 changes: 33 additions & 18 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25941,10 +25941,12 @@ def hll_union(
def theta_sketch_agg(
col: "ColumnOrName",
lgNomEntries: Optional[Union[int, Column]] = None,
family: Optional[str] = None,
) -> Column:
"""
Aggregate function: returns the compact binary representation of the Datasketches
ThetaSketch with the values in the input column configured with lgNomEntries nominal entries.
ThetaSketch with the values in the input column configured with lgNomEntries nominal entries
and the specified sketch family.

.. versionadded:: 4.1.0

Expand All @@ -25954,6 +25956,8 @@ def theta_sketch_agg(
lgNomEntries : :class:`~pyspark.sql.Column` or int, optional
The log-base-2 of nominal entries, where nominal entries is the size of the sketch
(must be between 4 and 26, defaults to 12)
family : str, optional
The sketch family: 'QUICKSELECT' or 'ALPHA' (defaults to 'QUICKSELECT').

Returns
-------
Expand All @@ -25974,24 +25978,35 @@ def theta_sketch_agg(
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([1,2,2,3], "INT")
>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show()
+--------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12))|
+--------------------------------------------------+
| 3|
+--------------------------------------------------+
+---------------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12, QUICKSELECT))|
+---------------------------------------------------------------+
| 3|
+---------------------------------------------------------------+

>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15))).show()
+--------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 15))|
+--------------------------------------------------+
| 3|
+--------------------------------------------------+
+---------------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 15, QUICKSELECT))|
+---------------------------------------------------------------+
| 3|
+---------------------------------------------------------------+

>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15, "ALPHA"))).show()
+---------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 15, ALPHA))|
+---------------------------------------------------------+
| 3|
+---------------------------------------------------------+
"""
fn = "theta_sketch_agg"
if lgNomEntries is None:
if lgNomEntries is None and family is None:
return _invoke_function_over_columns(fn, col)
else:
elif family is None:
return _invoke_function_over_columns(fn, col, lit(lgNomEntries))
else:
if lgNomEntries is None:
lgNomEntries = 12 # default value
return _invoke_function_over_columns(fn, col, lit(lgNomEntries), lit(family))


@_try_remote_functions
Expand Down Expand Up @@ -26118,11 +26133,11 @@ def theta_sketch_estimate(col: "ColumnOrName") -> Column:
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([1,2,2,3], "INT")
>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show()
+--------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12))|
+--------------------------------------------------+
| 3|
+--------------------------------------------------+
+---------------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12, QUICKSELECT))|
+---------------------------------------------------------------+
| 3|
+---------------------------------------------------------------+
"""

fn = "theta_sketch_estimate"
Expand Down
52 changes: 52 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,17 @@ object functions {
def theta_sketch_agg(e: Column, lgNomEntries: Column): Column =
Column.fn("theta_sketch_agg", e, lgNomEntries)

/**
 * Aggregate function: produces the compact binary form of a Datasketches ThetaSketch built
 * from the values of the input column, using the supplied `lgNomEntries` (log-base-2 of the
 * nominal entries) and sketch `family`.
 *
 * @param e
 *   column whose values are added to the sketch
 * @param lgNomEntries
 *   column holding the log-base-2 of the nominal entries
 * @param family
 *   column holding the sketch family name
 *
 * @group agg_funcs
 * @since 4.1.0
 */
def theta_sketch_agg(e: Column, lgNomEntries: Column, family: Column): Column = {
  // Arguments are forwarded positionally: input, lgNomEntries, family.
  val sketchArgs = Seq(e, lgNomEntries, family)
  Column.fn("theta_sketch_agg", sketchArgs: _*)
}

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column and configured with the `lgNomEntries` nominal
Expand Down Expand Up @@ -1242,6 +1253,47 @@ object functions {
def theta_sketch_agg(columnName: String): Column =
theta_sketch_agg(Column(columnName))

/**
 * Aggregate function: produces the compact binary form of a Datasketches ThetaSketch built
 * from the values of the input column, using the given `lgNomEntries` and `family`.
 *
 * @param e
 *   column whose values are added to the sketch
 * @param lgNomEntries
 *   log-base-2 of the nominal entries, passed on as a literal
 * @param family
 *   sketch family name, passed on as a literal
 *
 * @group agg_funcs
 * @since 4.1.0
 */
def theta_sketch_agg(e: Column, lgNomEntries: Int, family: String): Column =
  // Wrap the plain values as literal columns and reuse the all-Column overload.
  theta_sketch_agg(e, lit(lgNomEntries), lit(family))

/**
 * Aggregate function: produces the compact binary form of a Datasketches ThetaSketch built
 * from the values of the named column, using the given `lgNomEntries` and `family`.
 *
 * @param columnName
 *   name of the column whose values are added to the sketch
 * @param lgNomEntries
 *   log-base-2 of the nominal entries
 * @param family
 *   sketch family name
 *
 * @group agg_funcs
 * @since 4.1.0
 */
def theta_sketch_agg(columnName: String, lgNomEntries: Int, family: String): Column = {
  // Resolve the name to a Column, then defer to the Column-based overload.
  val inputColumn = Column(columnName)
  theta_sketch_agg(inputColumn, lgNomEntries, family)
}

/**
 * Aggregate function: produces the compact binary form of a Datasketches ThetaSketch built
 * from the values of the input column, using the given `family` and an lgNomEntries of 12.
 *
 * @param e
 *   column whose values are added to the sketch
 * @param family
 *   sketch family name
 *
 * @group agg_funcs
 * @since 4.1.0
 */
def theta_sketch_agg(e: Column, family: String): Column = {
  // lgNomEntries is not supplied by the caller here, so 12 is passed explicitly.
  val defaultLgNomEntries = 12
  theta_sketch_agg(e, defaultLgNomEntries, family)
}
Copy link
Contributor

@cboumalh cboumalh Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hardcoded 12 matches the Catalyst default, so behaviorally it’s fine. Still, duplicating an internal default in the public layer is a bit awkward. We could consider making this explicit (functions above preferred) or defining a local constant for clarity.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I gave this some thought before adding it here. Ideally, the decision to use 12 for lgNomEntries should live in ThetaSketchAgg,
i.e., if I pass just the column and family, lgNomEntries gets defaulted inside ThetaSketchAgg.

But
FunctionRegistry selects the constructor based on the number of expressions, so I cannot add a second constructor with the same signature as this one.
So I decided to apply the default in functions.scala. I also didn't see a similar pattern in any other function, so I am not very sure either.
Referencing a local constant in this class also seemed a bit odd, since it would be very specific.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd recommend we remove these two overloads, theta_sketch_agg(columnName: String, family: String) and theta_sketch_agg(e: Column, family: String), completely. If we later add another argument that is of String type, we'll have function overloading ambiguity. Therefore, if a user wants to use a specific family, they must also pass in the lgNomEntries explicitly. This will keep the file cleaner and avoid magic numbers. It also means we need to fix both builtin.py files to avoid the same issue. I'm open to hearing what others have to say about this too.


/**
 * Aggregate function: produces the compact binary form of a Datasketches ThetaSketch built
 * from the values of the named column, using the given `family` and an lgNomEntries of 12.
 *
 * @param columnName
 *   name of the column whose values are added to the sketch
 * @param family
 *   sketch family name
 *
 * @group agg_funcs
 * @since 4.1.0
 */
def theta_sketch_agg(columnName: String, family: String): Column =
  // The (Column, family) overload supplies the lgNomEntries value of 12.
  theta_sketch_agg(Column(columnName), family)

/**
* Aggregate function: returns the compact binary representation of the Datasketches
* ThetaSketch, generated by the union of Datasketches ThetaSketch instances in the input column
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.datasketches.common.Family
import org.apache.datasketches.memory.Memory
import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation, Sketch, Union, UpdateSketch, UpdateSketchBuilder}

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, ThetaSketchUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.types.StringTypeWithCollation
Expand Down Expand Up @@ -59,10 +59,12 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
*
* See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
*
* @param left
* @param first
* child expression against which unique counting will occur
* @param right
* @param second
* the log-base-2 of nomEntries decides the number of buckets for the sketch
* @param third
* the family of the sketch (QUICKSELECT or ALPHA)
* @param mutableAggBufferOffset
* offset for mutable aggregation buffer
* @param inputAggBufferOffset
Expand All @@ -71,46 +73,67 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(expr, lgNomEntries) - Returns the ThetaSketch compact binary representation.
_FUNC_(expr, lgNomEntries, family) - Returns the ThetaSketch compact binary representation.
`lgNomEntries` (optional) is the log-base-2 of nominal entries, with nominal entries deciding
the number buckets or slots for the ThetaSketch. """,
the number buckets or slots for the ThetaSketch.
`family` (optional) is the sketch family, either 'QUICKSELECT' or 'ALPHA' (defaults to
'QUICKSELECT').""",
examples = """
Examples:
> SELECT theta_sketch_estimate(_FUNC_(col)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
3
> SELECT theta_sketch_estimate(_FUNC_(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
3
> SELECT theta_sketch_estimate(_FUNC_(col, 15, 'ALPHA')) FROM VALUES (1), (1), (2), (2), (3) tab(col);
3
""",
group = "agg_funcs",
since = "4.1.0")
// scalastyle:on line.size.limit
case class ThetaSketchAgg(
left: Expression,
right: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
first: Expression,
second: Expression,
third: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[ThetaSketchState]
with BinaryLike[Expression]
with TernaryLike[Expression]
with ExpectsInputTypes {

// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.

lazy val lgNomEntries: Int = {
val lgNomEntriesInput = right.eval().asInstanceOf[Int]
private lazy val lgNomEntries: Int = {
val lgNomEntriesInput = second.eval().asInstanceOf[Int]
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName)
lgNomEntriesInput
}

// Constructors
// Sketch family, resolved on first use from the third child expression.
// NOTE(review): if `third` evaluates to null (e.g. a NULL literal), the toString call below
// would throw a NullPointerException before parseFamily can report the error — confirm
// whether NULL is rejected earlier.
private lazy val family: Family = {
  val requestedFamily = third.eval().asInstanceOf[UTF8String].toString
  ThetaSketchUtils.parseFamily(requestedFamily, prettyName)
}

// Constructors
def this(child: Expression) = {
this(child, Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS), 0, 0)
this(child,
Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS),
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)),
0, 0)
}

def this(child: Expression, lgNomEntries: Expression) = {
this(child, lgNomEntries, 0, 0)
this(child,
lgNomEntries,
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)),
0, 0)
}

// Three-expression form: input, lgNomEntries and family; both buffer offsets start at 0.
def this(child: Expression, lgNomEntries: Expression, family: Expression) =
  this(child, lgNomEntries, family, mutableAggBufferOffset = 0, inputAggBufferOffset = 0)

def this(child: Expression, lgNomEntries: Int) = {
this(child, Literal(lgNomEntries), 0, 0)
this(child, Literal(lgNomEntries))
}

// Copy constructors required by ImperativeAggregate
Expand All @@ -121,16 +144,11 @@ case class ThetaSketchAgg(
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchAgg =
copy(inputAggBufferOffset = newInputAggBufferOffset)

override protected def withNewChildrenInternal(
newLeft: Expression,
newRight: Expression): ThetaSketchAgg =
copy(left = newLeft, right = newRight)

// Overrides for TypedImperativeAggregate

override def prettyName: String = "theta_sketch_agg"

override def inputTypes: Seq[AbstractDataType] =
override def inputTypes: Seq[AbstractDataType] = {
Seq(
TypeCollection(
ArrayType(IntegerType),
Expand All @@ -141,21 +159,24 @@ case class ThetaSketchAgg(
IntegerType,
LongType,
StringTypeWithCollation(supportsTrimCollation = true)),
IntegerType)
IntegerType,
StringType)
}

override def dataType: DataType = BinaryType

override def nullable: Boolean = false

/**
* Instantiate an UpdateSketch instance using the lgNomEntries param.
* Instantiate an UpdateSketch instance using the lgNomEntries and family params.
*
* @return
* an UpdateSketch instance wrapped with UpdatableSketchBuffer
*/
override def createAggregationBuffer(): ThetaSketchState = {
  // Configure a fresh builder with the lazily evaluated lgNomEntries and family settings,
  // then wrap the built sketch so it can act as the aggregation buffer.
  val updateSketch = new UpdateSketchBuilder()
    .setLogNominalEntries(lgNomEntries)
    .setFamily(family)
    .build()
  UpdatableSketchBuffer(updateSketch)
}

Expand All @@ -176,7 +197,7 @@ case class ThetaSketchAgg(
*/
override def update(updateBuffer: ThetaSketchState, input: InternalRow): ThetaSketchState = {
// Return early for null values.
val v = left.eval(input)
val v = first.eval(input)
if (v == null) return updateBuffer

// Initialized buffer should be UpdatableSketchBuffer, else error out.
Expand All @@ -186,7 +207,7 @@ case class ThetaSketchAgg(
}

// Handle the different data types for sketch updates.
left.dataType match {
first.dataType match {
case ArrayType(IntegerType, _) =>
val arr = v.asInstanceOf[ArrayData].toIntArray()
sketch.update(arr)
Expand All @@ -213,7 +234,7 @@ case class ThetaSketchAgg(
case _ =>
throw new SparkUnsupportedOperationException(
errorClass = "_LEGACY_ERROR_TEMP_3121",
messageParameters = Map("dataType" -> left.dataType.toString))
messageParameters = Map("dataType" -> first.dataType.toString))
}

UpdatableSketchBuffer(sketch)
Expand Down Expand Up @@ -289,6 +310,10 @@ case class ThetaSketchAgg(
this.createAggregationBuffer()
}
}

/**
 * Returns a copy of this aggregate with its three child expressions replaced.
 *
 * Only the children are named in the `copy` call; the two aggregation-buffer offsets are
 * left to `copy`'s defaults, which are this instance's current values. The return type is
 * narrowed to `ThetaSketchAgg`, matching `withNewInputAggBufferOffset`.
 *
 * @param newFirst
 *   replacement for the input expression
 * @param newSecond
 *   replacement for the lgNomEntries expression
 * @param newThird
 *   replacement for the family expression
 */
override protected def withNewChildrenInternal(
    newFirst: Expression,
    newSecond: Expression,
    newThird: Expression): ThetaSketchAgg =
  copy(first = newFirst, second = newSecond, third = newThird)
Copy link
Contributor

@cboumalh cboumalh Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super nit: the withNewChildrenInternal override doesn’t need to explicitly pass mutableAggBufferOffset or inputAggBufferOffset; that is redundant. The return type should also be ThetaSketchAgg. Lastly, consider moving it up to line 146 to group it with the rest.

can have something like this:

  override protected def withNewChildrenInternal(
      newFirst: Expression, newSecond: Expression, newThird: Expression): ThetaSketchAgg =
    copy(
      first = newFirst,
      second = newSecond,
      third = newThird)
}

}

/**
Expand Down Expand Up @@ -331,7 +356,7 @@ case class ThetaUnionAgg(

// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.

lazy val lgNomEntries: Int = {
private lazy val lgNomEntries: Int = {
val lgNomEntriesInput = right.eval().asInstanceOf[Int]
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName)
lgNomEntriesInput
Expand Down
Loading