Skip to content

Add support for DataFrame sum operation with tests #1148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 26, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,13 @@ public fun AnyRow.rowSumOf(type: KType, skipNaN: Boolean = skipNaNDefault): Numb
// endregion

// region DataFrame

@Refine
@Interpretable("Sum0")
public fun <T> DataFrame<T>.sum(skipNaN: Boolean = skipNaNDefault): DataRow<T> =
sumFor(skipNaN, primitiveOrMixedNumberColumns())

@Refine
@Interpretable("Sum1")
public fun <T, C : Number?> DataFrame<T>.sumFor(
skipNaN: Boolean = skipNaNDefault,
columns: ColumnsForAggregateSelector<T, C>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,11 @@ import kotlin.reflect.full.withNullability
* @param Return The type of the resulting value. Can optionally be nullable.
* @see [invoke]
*/
@PublishedApi
internal class Aggregator<in Value : Any, out Return : Any?>(
val aggregationHandler: AggregatorAggregationHandler<Value, Return>,
val inputHandler: AggregatorInputHandler<Value, Return>,
val multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
val name: String,
public class Aggregator<in Value : Any, out Return : Any?>(
public val aggregationHandler: AggregatorAggregationHandler<Value, Return>,
public val inputHandler: AggregatorInputHandler<Value, Return>,
public val multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
public val name: String,
) : AggregatorInputHandler<Value, Return> by inputHandler,
AggregatorMultipleColumnsHandler<Value, Return> by multipleColumnsHandler,
AggregatorAggregationHandler<Value, Return> by aggregationHandler {
Expand Down Expand Up @@ -96,7 +95,7 @@ internal class Aggregator<in Value : Any, out Return : Any?>(
internal fun <Value : Any, Return : Any?> Aggregator<Value, Return>.aggregate(
values: Sequence<Value?>,
valueType: ValueType,
) = aggregateSequence(values, valueType)
): Return = aggregateSequence(values, valueType)

/**
* Performs aggregation on the given [values], taking [valueType] into account.
Expand All @@ -106,7 +105,7 @@ internal fun <Value : Any, Return : Any?> Aggregator<Value, Return>.aggregate(
internal fun <Value : Any, Return : Any?> Aggregator<Value, Return>.aggregate(
values: Sequence<Value?>,
valueType: KType,
) = aggregate(values, valueType.toValueType(needsFullConversion = false))
): Return = aggregate(values, valueType.toValueType(needsFullConversion = false))

/**
* If the specific [ValueType] of the input is not known, but you still want to call [aggregate],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import kotlin.reflect.KType
* It also provides information on which return type will be given, as [KType], given a [value type][ValueType].
* It can also provide the index of the result in the input values if it is a selecting aggregator.
*/
@PublishedApi
internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any?> : AggregatorHandler<Value, Return> {
public interface AggregatorAggregationHandler<in Value : Any, out Return : Any?> : AggregatorHandler<Value, Return> {

/**
* Base function of [Aggregator].
Expand All @@ -23,13 +22,13 @@ internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any
*
* When the exact [valueType] is unknown, use [calculateValueType] or [aggregateCalculatingValueType].
*/
fun aggregateSequence(values: Sequence<Value?>, valueType: ValueType): Return
public fun aggregateSequence(values: Sequence<Value?>, valueType: ValueType): Return

/**
* Aggregates the data in the given column and computes a single resulting value.
* Calls [aggregateSequence].
*/
fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
public fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
aggregateSequence(
values = column.asSequence(),
valueType = column.type().toValueType(),
Expand All @@ -43,7 +42,7 @@ internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any
* @param emptyInput If `true`, the input values are considered empty. This often affects the return type.
* @return The return type of [aggregateSequence] as [KType].
*/
fun calculateReturnType(valueType: KType, emptyInput: Boolean): KType
public fun calculateReturnType(valueType: KType, emptyInput: Boolean): KType

/**
* Function that can give the index of the aggregation result in the input [values], if it applies.
Expand All @@ -54,5 +53,5 @@ internal interface AggregatorAggregationHandler<in Value : Any, out Return : Any
*
* Defaults to `-1`.
*/
fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int
public fun indexOfAggregationResultSingleSequence(values: Sequence<Value?>, valueType: ValueType): Int
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,16 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
* the [init] function of each [AggregatorAggregationHandlers][AggregatorAggregationHandler] is called,
* which allows the handler to refer to [Aggregator] instance via [aggregator].
*/
internal interface AggregatorHandler<in Value : Any, out Return : Any?> {
public interface AggregatorHandler<in Value : Any, out Return : Any?> {

/**
* Reference to the aggregator instance.
*
* Can only be used once [init] has run.
*/
var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>?
public var aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>?

fun init(aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>) {
public fun init(aggregator: Aggregator<@UnsafeVariance Value, @UnsafeVariance Return>) {
this.aggregator = aggregator
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ import kotlin.reflect.KType
* It can also calculate a specific [value type][ValueType] from the input values or input types
* if the (specific) type is not known.
*/
internal interface AggregatorInputHandler<in Value : Any, out Return : Any?> : AggregatorHandler<Value, Return> {
public interface AggregatorInputHandler<in Value : Any, out Return : Any?> : AggregatorHandler<Value, Return> {

/**
* If the specific [ValueType] of the input is not known, but you still want to call [aggregate],
* this function can be called to calculate it by combining the set of known [valueTypes].
*/
fun calculateValueType(valueTypes: Set<KType>): ValueType
public fun calculateValueType(valueTypes: Set<KType>): ValueType

/**
* WARNING: HEAVY!
Expand All @@ -23,7 +23,7 @@ internal interface AggregatorInputHandler<in Value : Any, out Return : Any?> : A
* this function can be called to calculate it by getting the types of [values] at runtime.
* This is heavy because it uses reflection on each value.
*/
fun calculateValueType(values: Sequence<Value?>): ValueType
public fun calculateValueType(values: Sequence<Value?>): ValueType

/**
* Preprocesses the input values before aggregation.
Expand All @@ -32,7 +32,7 @@ internal interface AggregatorInputHandler<in Value : Any, out Return : Any?> : A
*
* @return A pair of the preprocessed values and the (potentially new) type of the values.
*/
fun preprocessAggregation(
public fun preprocessAggregation(
values: Sequence<Value?>,
valueType: ValueType,
): Pair<Sequence<@UnsafeVariance Value?>, KType>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import kotlin.reflect.KType
* [AggregatorAggregationHandler].
* It can also calculate the return type of the aggregation given all input column types.
*/
internal interface AggregatorMultipleColumnsHandler<in Value : Any, out Return : Any?> :
public interface AggregatorMultipleColumnsHandler<in Value : Any, out Return : Any?> :
AggregatorHandler<Value, Return> {

/**
* Aggregates the data in the multiple given columns and computes a single resulting value.
* Calls [Aggregator.aggregateSequence] or [Aggregator.aggregateSingleColumn].
*/
fun aggregateMultipleColumns(columns: Sequence<DataColumn<Value?>>): Return
public fun aggregateMultipleColumns(columns: Sequence<DataColumn<Value?>>): Return

/**
* Function that can give the return type of [aggregateMultipleColumns], given types of the columns.
Expand All @@ -26,5 +26,5 @@ internal interface AggregatorMultipleColumnsHandler<in Value : Any, out Return :
* @param colTypes The types of the input columns.
* @param colsEmpty If `true`, all the input columns are considered empty. This often affects the return type.
*/
fun calculateReturnTypeMultipleColumns(colTypes: Set<KType>, colsEmpty: Boolean): KType
public fun calculateReturnTypeMultipleColumns(colTypes: Set<KType>, colsEmpty: Boolean): KType
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,20 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
* Aggregators are cached by their parameter value.
* @see AggregatorOptionSwitch2
*/
@PublishedApi
internal class AggregatorOptionSwitch1<in Param1, in Value : Any, out Return : Any?>(
val name: String,
val getAggregator: (param1: Param1) -> AggregatorProvider<Value, Return>,
public class AggregatorOptionSwitch1<in Param1, in Value : Any, out Return : Any?>(
public val name: String,
public val getAggregator: (param1: Param1) -> AggregatorProvider<Value, Return>,
) {

private val cache: MutableMap<Param1, Aggregator<Value, Return>> = mutableMapOf()

operator fun invoke(param1: Param1): Aggregator<Value, Return> =
public operator fun invoke(param1: Param1): Aggregator<Value, @UnsafeVariance Return> =
cache.getOrPut(param1) {
getAggregator(param1).create(name)
}

@Suppress("FunctionName")
companion object {
public companion object {

/**
* Creates [AggregatorOptionSwitch1].
Expand All @@ -31,9 +30,10 @@ internal class AggregatorOptionSwitch1<in Param1, in Value : Any, out Return : A
* MyAggregator.Factory(param1)
* }
*/
fun <Param1, Value : Any, Return : Any?> Factory(
public fun <Param1, Value : Any, Return : Any?> Factory(
getAggregator: (param1: Param1) -> AggregatorProvider<Value, Return>,
) = Provider { name -> AggregatorOptionSwitch1(name, getAggregator) }
): Provider<AggregatorOptionSwitch1<Param1, Value, Return>> =
Provider { name -> AggregatorOptionSwitch1(name, getAggregator) }
}
}

Expand All @@ -43,21 +43,20 @@ internal class AggregatorOptionSwitch1<in Param1, in Value : Any, out Return : A
* Aggregators are cached by their parameter values.
* @see AggregatorOptionSwitch1
*/
@PublishedApi
internal class AggregatorOptionSwitch2<in Param1, in Param2, in Value : Any, out Return : Any?>(
val name: String,
val getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider<Value, Return>,
public class AggregatorOptionSwitch2<in Param1, in Param2, in Value : Any, out Return : Any?>(
public val name: String,
public val getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider<Value, Return>,
) {

private val cache: MutableMap<Pair<Param1, Param2>, Aggregator<Value, Return>> = mutableMapOf()

operator fun invoke(param1: Param1, param2: Param2): Aggregator<Value, Return> =
public operator fun invoke(param1: Param1, param2: Param2): Aggregator<Value, @UnsafeVariance Return> =
cache.getOrPut(param1 to param2) {
getAggregator(param1, param2).create(name)
}

@Suppress("FunctionName")
companion object {
public companion object {

/**
* Creates [AggregatorOptionSwitch2].
Expand All @@ -68,7 +67,7 @@ internal class AggregatorOptionSwitch2<in Param1, in Param2, in Value : Any, out
* MyAggregator.Factory(param1, param2)
* }
*/
fun <Param1, Param2, Value : Any, Return : Any?> Factory(
internal fun <Param1, Param2, Value : Any, Return : Any?> Factory(
getAggregator: (param1: Param1, param2: Param2) -> AggregatorProvider<Value, Return>,
) = Provider { name -> AggregatorOptionSwitch2(name, getAggregator) }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import kotlin.reflect.KProperty
* val myNamedValue by MyFactory
* ```
*/
internal fun interface Provider<out T> {
public fun interface Provider<out T> {

fun create(name: String): T
public fun create(name: String): T
}

internal operator fun <T> Provider<T>.getValue(obj: Any?, property: KProperty<*>): T = create(property.name)
Expand All @@ -25,4 +25,5 @@ internal operator fun <T> Provider<T>.getValue(obj: Any?, property: KProperty<*>
* val myAggregator by MyAggregator.Factory
* ```
*/
internal fun interface AggregatorProvider<in Value : Any, out Return : Any?> : Provider<Aggregator<Value, Return>>
public fun interface AggregatorProvider<in Value : Any, out Return : Any?> :
Provider<Aggregator<Value, @UnsafeVariance Return>>
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ import org.jetbrains.kotlinx.dataframe.math.stdTypeConversion
import org.jetbrains.kotlinx.dataframe.math.sum
import org.jetbrains.kotlinx.dataframe.math.sumTypeConversion

@PublishedApi
internal object Aggregators {
public object Aggregators {

// TODO these might need some small refactoring

Expand Down Expand Up @@ -112,7 +111,7 @@ internal object Aggregators {

// T: Comparable<T> -> T?
// T : Comparable<T & Any>? -> T?
fun <T : Comparable<T & Any>?> min(skipNaN: Boolean): Aggregator<T & Any, T?> = min.invoke(skipNaN).cast2()
public fun <T : Comparable<T & Any>?> min(skipNaN: Boolean): Aggregator<T & Any, T?> = min.invoke(skipNaN).cast2()

private val min by withOneOption { skipNaN: Boolean ->
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
Expand All @@ -124,7 +123,7 @@ internal object Aggregators {

// T: Comparable<T> -> T?
// T : Comparable<T & Any>? -> T?
fun <T : Comparable<T & Any>?> max(skipNaN: Boolean): Aggregator<T & Any, T?> = max.invoke(skipNaN).cast2()
public fun <T : Comparable<T & Any>?> max(skipNaN: Boolean): Aggregator<T & Any, T?> = max.invoke(skipNaN).cast2()

private val max by withOneOption { skipNaN: Boolean ->
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
Expand All @@ -135,36 +134,41 @@ internal object Aggregators {
}

// T: Number? -> Double
val std by withTwoOptions { skipNaN: Boolean, ddof: Int ->
public val std: AggregatorOptionSwitch2<Boolean, Int, Number, Double> by withTwoOptions {
skipNaN: Boolean,
ddof: Int,
->
flattenReducingForNumbers(stdTypeConversion) { type ->
std(type, skipNaN, ddof)
}
}

// step one: T: Number? -> Double
// step two: Double -> Double
val mean by withOneOption { skipNaN: Boolean ->
public val mean: AggregatorOptionSwitch1<Boolean, Number, Double> by withOneOption { skipNaN: Boolean ->
twoStepReducingForNumbers(meanTypeConversion) { type ->
mean(type, skipNaN)
}
}

// T : primitive Number? -> Double?
// T : Comparable<T & Any>? -> T?
fun <T> percentileCommon(
public fun <T> percentileCommon(
percentile: Double,
skipNaN: Boolean,
): Aggregator<T & Any, T?>
where T : Comparable<T & Any>? =
this.percentile.invoke(percentile, skipNaN).cast2()

// T : Comparable<T & Any>? -> T?
fun <T> percentileComparables(percentile: Double): Aggregator<T & Any, T?>
public fun <T> percentileComparables(
percentile: Double,
): Aggregator<T & Any, T?>
where T : Comparable<T & Any>? =
percentileCommon<T>(percentile, skipNaNDefault).cast2()

// T : primitive Number? -> Double?
fun <T> percentileNumbers(
public fun <T> percentileNumbers(
percentile: Double,
skipNaN: Boolean,
): Aggregator<T & Any, Double?>
Expand All @@ -182,17 +186,17 @@ internal object Aggregators {

// T : primitive Number? -> Double?
// T : Comparable<T & Any>? -> T?
fun <T> medianCommon(skipNaN: Boolean): Aggregator<T & Any, T?>
public fun <T> medianCommon(skipNaN: Boolean): Aggregator<T & Any, T?>
where T : Comparable<T & Any>? =
median.invoke(skipNaN).cast2()

// T : Comparable<T & Any>? -> T?
fun <T> medianComparables(): Aggregator<T & Any, T?>
public fun <T> medianComparables(): Aggregator<T & Any, T?>
where T : Comparable<T & Any>? =
medianCommon<T>(skipNaNDefault).cast2()

// T : primitive Number? -> Double?
fun <T> medianNumbers(
public fun <T> medianNumbers(
skipNaN: Boolean,
): Aggregator<T & Any, Double?>
where T : Comparable<T & Any>?, T : Number? =
Expand All @@ -211,7 +215,7 @@ internal object Aggregators {
// Byte -> Int
// Short -> Int
// Nothing -> Double
val sum by withOneOption { skipNaN: Boolean ->
public val sum: AggregatorOptionSwitch1<Boolean, Number, Number> by withOneOption { skipNaN: Boolean ->
twoStepReducingForNumbers(sumTypeConversion) { type ->
sum(type, skipNaN)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ import kotlin.reflect.KType
* for the values to become the correct value type. If `false`, the values are already the right type,
* or a simple cast will suffice.
*/
internal data class ValueType(val kType: KType, val needsFullConversion: Boolean = false)
public data class ValueType(val kType: KType, val needsFullConversion: Boolean = false)

internal fun KType.toValueType(needsFullConversion: Boolean = false): ValueType = ValueType(this, needsFullConversion)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import kotlin.reflect.KType
* If not supplied, the handler of the first step is reused.
* @see [FlatteningMultipleColumnsHandler]
*/
internal class TwoStepMultipleColumnsHandler<in Value : Any, out Return : Any?>(
internal class TwoStepMultipleColumnsHandler<in Value : Any, Return : Any?>(
stepTwoAggregationHandler: AggregatorAggregationHandler<Return & Any, Return>? = null,
stepTwoInputHandler: AggregatorInputHandler<Return & Any, Return>? = null,
) : AggregatorMultipleColumnsHandler<Value, Return> {
Expand Down
Loading