Skip to content

Commit

Permalink
Fix min and max aggregates for floating points
Browse files Browse the repository at this point in the history
Summary:
This change ensures that extreme floating point values are handled
correctly, including INF, -INF, and NaN, where NaN is considered
greater than INF. Additionally, it correctly sets the initial group
values for floating point types (-INF for max() and NaN for min()),
which is relevant when the inputs consist solely of these extreme
values.

Differential Revision: D57801191
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed May 24, 2024
1 parent 1a0d26a commit 44b02bf
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 28 deletions.
10 changes: 10 additions & 0 deletions velox/docs/functions/presto/aggregate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,34 @@ General Aggregate Functions
Returns the maximum value of all input values.
``x`` must not contain nulls when it is complex type.
``x`` must be an orderable type.
Nulls are ignored if there are any non-null inputs.
For REAL and DOUBLE types, NaN is considered greater than Infinity.

.. function:: max(x, n) -> array<[same as x]>
:noindex:

Returns ``n`` largest values of all input values of ``x``.
``n`` must be a positive integer and not exceed 10'000.
Currently not supported for ARRAY, MAP, and ROW input types.
Nulls are not included in the output array.
For REAL and DOUBLE types, NaN is considered greater than Infinity.

.. function:: min(x) -> [same as x]

Returns the minimum value of all input values.
``x`` must not contain nulls when it is complex type.
``x`` must be an orderable type.
Nulls are ignored if there are any non-null inputs.
For REAL and DOUBLE types, NaN is considered greater than Infinity.

.. function:: min(x, n) -> array<[same as x]>
:noindex:

Returns ``n`` smallest values of all input values of ``x``.
``n`` must be a positive integer and not exceed 10'000.
Currently not supported for ARRAY, MAP, and ROW input types.
Nulls are not included in output array.
For REAL and DOUBLE types, NaN is considered greater than Infinity.

.. function:: multimap_agg(K key, V value) -> map(K,array(V))

Expand Down
97 changes: 69 additions & 28 deletions velox/functions/prestosql/aggregates/MinMaxAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "velox/functions/lib/aggregates/SingleValueAccumulator.h"
#include "velox/functions/prestosql/aggregates/AggregateNames.h"
#include "velox/functions/prestosql/aggregates/Compare.h"
#include "velox/type/FloatingPointUtil.h"

using namespace facebook::velox::functions::aggregate;

Expand Down Expand Up @@ -119,15 +120,7 @@ class MaxAggregate : public MinMaxAggregate<T> {
return;
}
BaseAggregate::template updateGroups<true, T>(
groups,
rows,
args[0],
[](T& result, T value) {
if (result < value) {
result = value;
}
},
mayPushdown);
groups, rows, args[0], updateGroup, mayPushdown);
}

void addIntermediateResults(
Expand All @@ -147,11 +140,7 @@ class MaxAggregate : public MinMaxAggregate<T> {
group,
rows,
args[0],
[](T& result, T value) {
if (result < value) {
result = value;
}
},
updateGroup,
[](T& result, T value, int /* unused */) { result = value; },
mayPushdown,
kInitialValue_);
Expand All @@ -175,13 +164,34 @@ class MaxAggregate : public MinMaxAggregate<T> {
}
}

static inline void updateGroup(T& result, T value) {
if constexpr (std::is_floating_point_v<T>) {
if (util::floating_point::NaNAwareLessThan<T>{}(result, value)) {
result = value;
}
} else {
if (result < value) {
result = value;
}
}
}

private:
static const T kInitialValue_;
};

template <typename T>
const T MaxAggregate<T>::kInitialValue_ = MinMaxTrait<T>::lowest();

// Negative INF is the smallest value of floating point type.
template <>
const float MaxAggregate<float>::kInitialValue_ =
-1 * MinMaxTrait<float>::infinity();

template <>
const double MaxAggregate<double>::kInitialValue_ =
-1 * MinMaxTrait<double>::infinity();

template <typename T>
class MinAggregate : public MinMaxAggregate<T> {
using BaseAggregate = SimpleNumericAggregate<T, T, T>;
Expand All @@ -205,15 +215,7 @@ class MinAggregate : public MinMaxAggregate<T> {
return;
}
BaseAggregate::template updateGroups<true, T>(
groups,
rows,
args[0],
[](T& result, T value) {
if (result > value) {
result = value;
}
},
mayPushdown);
groups, rows, args[0], updateGroup, mayPushdown);
}

void addIntermediateResults(
Expand All @@ -233,7 +235,7 @@ class MinAggregate : public MinMaxAggregate<T> {
group,
rows,
args[0],
[](T& result, T value) { result = result < value ? result : value; },
updateGroup,
[](T& result, T value, int /* unused */) { result = value; },
mayPushdown,
kInitialValue_);
Expand All @@ -248,6 +250,18 @@ class MinAggregate : public MinMaxAggregate<T> {
}

protected:
static inline void updateGroup(T& result, T value) {
if constexpr (std::is_floating_point_v<T>) {
if (util::floating_point::NaNAwareGreaterThan<T>{}(result, value)) {
result = value;
}
} else {
if (result > value) {
result = value;
}
}
}

void initializeNewGroupsInternal(
char** groups,
folly::Range<const vector_size_t*> indices) override {
Expand All @@ -264,6 +278,15 @@ class MinAggregate : public MinMaxAggregate<T> {
template <typename T>
const T MinAggregate<T>::kInitialValue_ = MinMaxTrait<T>::max();

// In velox, NaN is considered larger than infinity for floating point types.
template <>
const float MinAggregate<float>::kInitialValue_ =
MinMaxTrait<float>::quiet_NaN();

template <>
const double MinAggregate<double>::kInitialValue_ =
MinMaxTrait<double>::quiet_NaN();

class NonNumericMinMaxAggregateBase : public exec::Aggregate {
public:
explicit NonNumericMinMaxAggregateBase(
Expand Down Expand Up @@ -878,17 +901,35 @@ class MinMaxNAggregateBase : public exec::Aggregate {
};

template <typename T>
class MinNAggregate : public MinMaxNAggregateBase<T, std::less<T>> {
struct LessThanComparator : public std::less<T> {};
template <>
struct LessThanComparator<float>
: public util::floating_point::NaNAwareLessThan<float> {};
template <>
struct LessThanComparator<double>
: public util::floating_point::NaNAwareLessThan<double> {};

template <typename T>
struct GreaterThanComparator : public std::less<T> {};
template <>
struct GreaterThanComparator<float>
: public util::floating_point::NaNAwareGreaterThan<float> {};
template <>
struct GreaterThanComparator<double>
: public util::floating_point::NaNAwareGreaterThan<double> {};

template <typename T>
class MinNAggregate : public MinMaxNAggregateBase<T, LessThanComparator<T>> {
public:
explicit MinNAggregate(const TypePtr& resultType)
: MinMaxNAggregateBase<T, std::less<T>>(resultType) {}
: MinMaxNAggregateBase<T, LessThanComparator<T>>(resultType) {}
};

template <typename T>
class MaxNAggregate : public MinMaxNAggregateBase<T, std::greater<T>> {
class MaxNAggregate : public MinMaxNAggregateBase<T, GreaterThanComparator<T>> {
public:
explicit MaxNAggregate(const TypePtr& resultType)
: MinMaxNAggregateBase<T, std::greater<T>>(resultType) {}
: MinMaxNAggregateBase<T, GreaterThanComparator<T>>(resultType) {}
};

template <
Expand Down
Loading

0 comments on commit 44b02bf

Please sign in to comment.