diff --git a/velox/docs/functions/presto/aggregate.rst b/velox/docs/functions/presto/aggregate.rst index 7be84b469cc58..d36ce7f5b0175 100644 --- a/velox/docs/functions/presto/aggregate.rst +++ b/velox/docs/functions/presto/aggregate.rst @@ -117,24 +117,34 @@ General Aggregate Functions Returns the maximum value of all input values. ``x`` must not contain nulls when it is complex type. ``x`` must be an orderable type. + Nulls are ignored if there are any non-null inputs. + For REAL and DOUBLE types, NaN is considered greater than Infinity. .. function:: max(x, n) -> array<[same as x]> :noindex: Returns ``n`` largest values of all input values of ``x``. ``n`` must be a positive integer and not exceed 10'000. + Currently not supported for ARRAY, MAP, and ROW input types. + Nulls are not included in the output array. + For REAL and DOUBLE types, NaN is considered greater than Infinity. .. function:: min(x) -> [same as x] Returns the minimum value of all input values. ``x`` must not contain nulls when it is complex type. ``x`` must be an orderable type. + Nulls are ignored if there are any non-null inputs. + For REAL and DOUBLE types, NaN is considered greater than Infinity. .. function:: min(x, n) -> array<[same as x]> :noindex: Returns ``n`` smallest values of all input values of ``x``. ``n`` must be a positive integer and not exceed 10'000. + Currently not supported for ARRAY, MAP, and ROW input types. + Nulls are not included in output array. + For REAL and DOUBLE types, NaN is considered greater than Infinity. .. function:: multimap_agg(K key, V value) -> map(K,array(V)) diff --git a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp index b5b7a88903aad..ae4af7107da08 100644 --- a/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxAggregates.cpp @@ -22,6 +22,7 @@ #include "velox/functions/lib/aggregates/SingleValueAccumulator.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" #include "velox/functions/prestosql/aggregates/Compare.h" +#include "velox/type/FloatingPointUtil.h" using namespace facebook::velox::functions::aggregate; @@ -119,15 +120,7 @@ class MaxAggregate : public MinMaxAggregate { return; } BaseAggregate::template updateGroups( - groups, - rows, - args[0], - [](T& result, T value) { - if (result < value) { - result = value; - } - }, - mayPushdown); + groups, rows, args[0], updateGroup, mayPushdown); } void addIntermediateResults( @@ -147,11 +140,7 @@ class MaxAggregate : public MinMaxAggregate { group, rows, args[0], - [](T& result, T value) { - if (result < value) { - result = value; - } - }, + updateGroup, [](T& result, T value, int /* unused */) { result = value; }, mayPushdown, kInitialValue_); @@ -175,6 +164,18 @@ class MaxAggregate : public MinMaxAggregate { } } + static inline void updateGroup(T& result, T value) { + if constexpr (std::is_floating_point_v) { + if (util::floating_point::NaNAwareLessThan{}(result, value)) { + result = value; + } + } else { + if (result < value) { + result = value; + } + } + } + private: static const T kInitialValue_; }; @@ -182,6 +183,15 @@ class MaxAggregate : public MinMaxAggregate { template const T MaxAggregate::kInitialValue_ = MinMaxTrait::lowest(); +// Negative INF is the smallest value of floating point type. +template <> +const float MaxAggregate::kInitialValue_ = + -1 * MinMaxTrait::infinity(); + +template <> +const double MaxAggregate::kInitialValue_ = + -1 * MinMaxTrait::infinity(); + template class MinAggregate : public MinMaxAggregate { using BaseAggregate = SimpleNumericAggregate; @@ -205,15 +215,7 @@ class MinAggregate : public MinMaxAggregate { return; } BaseAggregate::template updateGroups( - groups, - rows, - args[0], - [](T& result, T value) { - if (result > value) { - result = value; - } - }, - mayPushdown); + groups, rows, args[0], updateGroup, mayPushdown); } void addIntermediateResults( @@ -233,7 +235,7 @@ class MinAggregate : public MinMaxAggregate { group, rows, args[0], - [](T& result, T value) { result = result < value ? result : value; }, + updateGroup, [](T& result, T value, int /* unused */) { result = value; }, mayPushdown, kInitialValue_); @@ -248,6 +250,18 @@ class MinAggregate : public MinMaxAggregate { } protected: + static inline void updateGroup(T& result, T value) { + if constexpr (std::is_floating_point_v) { + if (util::floating_point::NaNAwareGreaterThan{}(result, value)) { + result = value; + } + } else { + if (result > value) { + result = value; + } + } + } + void initializeNewGroupsInternal( char** groups, folly::Range indices) override { @@ -264,6 +278,15 @@ class MinAggregate : public MinMaxAggregate { template const T MinAggregate::kInitialValue_ = MinMaxTrait::max(); +// In velox, NaN is considered larger than infinity for floating point types. +template <> +const float MinAggregate::kInitialValue_ = + MinMaxTrait::quiet_NaN(); + +template <> +const double MinAggregate::kInitialValue_ = + MinMaxTrait::quiet_NaN(); + class NonNumericMinMaxAggregateBase : public exec::Aggregate { public: explicit NonNumericMinMaxAggregateBase( @@ -878,17 +901,35 @@ class MinMaxNAggregateBase : public exec::Aggregate { }; template -class MinNAggregate : public MinMaxNAggregateBase> { +struct LessThanComparator : public std::less {}; +template <> +struct LessThanComparator + : public util::floating_point::NaNAwareLessThan {}; +template <> +struct LessThanComparator + : public util::floating_point::NaNAwareLessThan {}; + +template +struct GreaterThanComparator : public std::less {}; +template <> +struct GreaterThanComparator + : public util::floating_point::NaNAwareGreaterThan {}; +template <> +struct GreaterThanComparator + : public util::floating_point::NaNAwareGreaterThan {}; + +template +class MinNAggregate : public MinMaxNAggregateBase> { public: explicit MinNAggregate(const TypePtr& resultType) - : MinMaxNAggregateBase>(resultType) {} + : MinMaxNAggregateBase>(resultType) {} }; template -class MaxNAggregate : public MinMaxNAggregateBase> { +class MaxNAggregate : public MinMaxNAggregateBase> { public: explicit MaxNAggregate(const TypePtr& resultType) - : MinMaxNAggregateBase>(resultType) {} + : MinMaxNAggregateBase>(resultType) {} }; template < diff --git a/velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp b/velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp index 7e1f352484a0c..f9eb5cfa74dde 100644 --- a/velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp @@ -104,6 +104,135 @@ class MinMaxTest : public functions::aggregate::test::AggregationTestBase { {agg(c1)}, fmt::format("SELECT {} FROM tmp WHERE c0 % 2 = 0", agg(c1))); } + + template + void testExtremeFloatValues() { + // Tests to ensure that extreme floating point values are handled correctly, + // including, INF, -INF, NaN. This validates that the groups have initial + // value set correctly, (-INF for max() and NaN for min()) and NaN is + // considered greater than INF. Also tests for when floating points are + // nested inside complex types. + static const T kNaN = std::numeric_limits::quiet_NaN(); + static const T kSNaN = std::numeric_limits::signaling_NaN(); + static const T kInf = std::numeric_limits::infinity(); + + auto data = makeRowVector({ + // regular ordering + makeFlatVector({2.0, kNaN, 1.1, kInf, -1.1}), + // with nulls + makeNullableFlatVector({2.0, kNaN, std::nullopt, 1.1, -1.1}), + // only nans (use a different binary representation for NaN to verify + // that they are considered equal) + makeFlatVector({kSNaN, kSNaN, kSNaN, kSNaN, kSNaN}), + // only Inf + makeFlatVector({kInf, kInf, kInf, kInf, kInf}), + // only -Inf + makeFlatVector({-kInf, -kInf, -kInf, -kInf, -kInf}), + // group by column + makeFlatVector({1, 1, 1, 2, 2}), + }); + + // Global aggregation. + { + auto expected = makeRowVector( + {makeFlatVector(std::vector({-1.1})), + makeFlatVector(std::vector({kNaN})), + makeFlatVector(std::vector({-1.1})), + makeFlatVector(std::vector({kNaN})), + makeFlatVector(std::vector({kNaN})), + makeFlatVector(std::vector({kNaN})), + makeFlatVector(std::vector({kInf})), + makeFlatVector(std::vector({kInf})), + makeFlatVector(std::vector({-kInf})), + makeFlatVector(std::vector({-kInf}))}); + + testAggregations( + {data}, + {}, + {"min(c0)", + "max(c0)", + "min(c1)", + "max(c1)", + "min(c2)", + "max(c2)", + "min(c3)", + "max(c3)", + "min(c4)", + "max(c4)"}, + {expected}); + } + + // group-by aggregation. + { + auto expected = makeRowVector( + {makeFlatVector({1, 2}), + makeFlatVector({1.1, -1.1}), + makeFlatVector({kNaN, kInf}), + makeFlatVector({2.0, -1.1}), + makeFlatVector({kNaN, 1.1}), + makeFlatVector({kNaN, kNaN}), + makeFlatVector({kNaN, kNaN}), + makeFlatVector({kInf, kInf}), + makeFlatVector({kInf, kInf}), + makeFlatVector({-kInf, -kInf}), + makeFlatVector({-kInf, -kInf})}); + + testAggregations( + {data}, + {"c5"}, + {"min(c0)", + "max(c0)", + "min(c1)", + "max(c1)", + "min(c2)", + "max(c2)", + "min(c3)", + "max(c3)", + "min(c4)", + "max(c4)"}, + {expected}); + } + + // Test for float point values nested inside complex type. + data = makeRowVector({ + makeRowVector({ + makeFlatVector({2, kNaN, 1, kInf, -1, kNaN}), + makeFlatVector({1, 1, 1, 2, 2, 2}), + }), + makeFlatVector({1, 1, 1, 2, 2, 2}), + }); + + // Global aggregation. + { + auto expected = makeRowVector( + {makeRowVector({ + makeFlatVector(std::vector({-1})), + makeFlatVector(std::vector({2})), + }), + makeRowVector({ + makeFlatVector(std::vector({kNaN})), + makeFlatVector(std::vector({2})), + })}); + + testAggregations({data}, {}, {"min(c0)", "max(c0)"}, {expected}); + } + + // group-by aggregation. + { + auto expected = makeRowVector( + {makeFlatVector({1, 2}), + makeRowVector({ + makeFlatVector(std::vector({1, -1})), + makeFlatVector(std::vector({1, 2})), + }), + makeRowVector({ + makeFlatVector(std::vector({kNaN, kNaN})), + makeFlatVector(std::vector({1, 2})), + })}); + + testAggregations({data}, {"c1"}, {"min(c0)", "max(c0)"}, {expected}); + } + } }; TEST_F(MinMaxTest, maxTinyint) { @@ -124,10 +253,12 @@ TEST_F(MinMaxTest, maxBigint) { TEST_F(MinMaxTest, maxReal) { doTest(max, REAL()); + testExtremeFloatValues(); } TEST_F(MinMaxTest, maxDouble) { doTest(max, DOUBLE()); + testExtremeFloatValues(); } TEST_F(MinMaxTest, maxVarchar) { @@ -936,6 +1067,59 @@ class MinMaxNTest : public functions::aggregate::test::AggregationTestBase { {"min(c1, 2)", "min(c1, 5)", "max(c1, 3)", "max(c1, 7)"}, {expected}); } + + template + void testNaNFloatValues() { + // Tests to ensure NaN is correctly handled and considered greater than + // Infinity. + static const T kNaN = std::numeric_limits::quiet_NaN(); + static const T kInf = std::numeric_limits::infinity(); + + auto data = makeRowVector({ + // regular ordering + makeFlatVector({2.0, kNaN, kInf, kNaN, -1.1, 0.0}), + // with nulls (null is ignored) + makeNullableFlatVector({2.0, kNaN, std::nullopt, 1.1, -1.1, 0.0}), + // group by column + makeFlatVector({1, 1, 1, 2, 2, 2}), + }); + + // Global aggregation. + { + auto expected = makeRowVector( + {makeArrayVector({{-1.1, 0.0, 2.0, kInf, kNaN, kNaN}}), + makeArrayVector({{kNaN, kNaN, kInf, 2.0, 0.0, -1.1}}), + makeArrayVector({{-1.1, 0.0, 1.1, 2.0, kNaN}}), + makeArrayVector({{kNaN, 2.0, 1.1, 0.0, -1.1}})}); + + testAggregations( + {data}, + {}, + { + "min(c0, 6)", + "max(c0, 6)", + "min(c1, 6)", + "max(c1, 6)", + }, + {expected}); + } + + // group-by aggregation. + { + auto expected = makeRowVector( + {makeFlatVector({1, 2}), + makeArrayVector({{2.0, kInf, kNaN}, {-1.1, 0.0, kNaN}}), + makeArrayVector({{kNaN, kInf, 2.0}, {kNaN, 0.0, -1.1}}), + makeArrayVector({{2.0, kNaN}, {-1.1, 0.0, 1.1}}), + makeArrayVector({{kNaN, 2.0}, {1.1, 0.0, -1.1}})}); + + testAggregations( + {data}, + {"c2"}, + {"min(c0, 3)", "max(c0, 3)", "min(c1, 3)", "max(c1, 3)"}, + {expected}); + } + } }; TEST_F(MinMaxNTest, tinyint) { @@ -961,11 +1145,13 @@ TEST_F(MinMaxNTest, bigint) { TEST_F(MinMaxNTest, real) { testNumericGlobal(); testNumericGroupBy(); + testNaNFloatValues(); } TEST_F(MinMaxNTest, double) { testNumericGlobal(); testNumericGroupBy(); + testNaNFloatValues(); } TEST_F(MinMaxNTest, shortdecimal) {