diff --git a/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp b/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp index 50e1e085faf5..71f020ca398c 100644 --- a/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp +++ b/velox/functions/prestosql/aggregates/MinMaxByAggregates.cpp @@ -19,6 +19,7 @@ #include "velox/functions/lib/aggregates/MinMaxByAggregatesBase.h" #include "velox/functions/lib/aggregates/ValueSet.h" #include "velox/functions/prestosql/aggregates/AggregateNames.h" +#include "velox/type/FloatingPointUtil.h" using namespace facebook::velox::functions::aggregate; @@ -40,8 +41,16 @@ struct Comparator { return true; } if constexpr (greaterThan) { + if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareGreaterThan{}( + newComparisons.valueAt(index), *accumulator); + } return newComparisons.valueAt(index) > *accumulator; } else { + if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareLessThan{}( + newComparisons.valueAt(index), *accumulator); + } return newComparisons.valueAt(index) < *accumulator; } } else { @@ -560,10 +569,16 @@ template struct Less { using Pair = std::pair>; bool operator()(const Pair& lhs, const Pair& rhs) { + if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareLessThan{}(lhs.first, rhs.first); + } return lhs.first < rhs.first; } bool compare(C lhs, const Pair& rhs) { + if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareLessThan{}(lhs, rhs.first); + } return lhs < rhs.first; } }; @@ -572,10 +587,17 @@ template struct Greater { using Pair = std::pair>; bool operator()(const Pair& lhs, const Pair& rhs) { + if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareGreaterThan{}( + lhs.first, rhs.first); + } return lhs.first > rhs.first; } bool compare(C lhs, const Pair& rhs) { + if constexpr (std::is_floating_point_v) { + return util::floating_point::NaNAwareGreaterThan{}(lhs, rhs.first); + } return lhs > rhs.first; } }; diff --git a/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp b/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp index ecac491c44d4..8f3285b95e15 100644 --- a/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/MinMaxByAggregationTest.cpp @@ -1444,8 +1444,76 @@ class MinMaxByNTest : public AggregationTestBase { AggregationTestBase::SetUp(); AggregationTestBase::enableTestStreaming(); } + + template + void testNanFloatValues() { + // Verify that NaN values are handeled correctly as being greater than + // infinity. + static const T kNaN = std::numeric_limits::quiet_NaN(); + static const T kSNaN = std::numeric_limits::signaling_NaN(); + static const T kInf = std::numeric_limits::infinity(); + + auto data = makeRowVector({ + // output column for min_by/max_by + makeFlatVector({1, 2, 3, 4, 5}), + // group by column + makeFlatVector({1, 1, 2, 2, 2}), + // regular ordering + makeFlatVector({2.0, kNaN, 1.1, kInf, -1.1}), + // with nulls + makeNullableFlatVector({2.0, 1.1, std::nullopt, kSNaN, -1.1}), + }); + + // Global aggregation. + { + auto expected = makeRowVector({ + makeArrayVectorFromJson({"[2, 4]"}), + makeArrayVectorFromJson({"[4, 1]"}), + makeArrayVectorFromJson({"[5, 3]"}), + makeArrayVectorFromJson({"[5, 2]"}), + }); + + testAggregations( + {data}, + {}, + { + "max_by(c0, c2, 2)", + "max_by(c0, c3, 2)", + "min_by(c0, c2, 2)", + "min_by(c0, c3, 2)", + }, + {expected}); + } + + // group-by aggregation. + { + auto expected = makeRowVector({ + makeFlatVector({1, 2}), // grouping key + makeArrayVectorFromJson({"[2, 1]", "[4, 3]"}), + makeArrayVectorFromJson({"[1, 2]", "[4, 5]"}), + makeArrayVectorFromJson({"[1, 2]", "[5, 3]"}), + makeArrayVectorFromJson({"[2, 1]", "[5, 4]"}), + }); + + testAggregations( + {data}, + {"c1"}, + { + "max_by(c0, c2, 2)", + "max_by(c0, c3, 2)", + "min_by(c0, c2, 2)", + "min_by(c0, c3, 2)", + }, + {expected}); + } + } }; +TEST_F(MinMaxByNTest, nans) { + testNanFloatValues(); + testNanFloatValues(); +} + TEST_F(MinMaxByNTest, global) { // DuckDB doesn't support 3-argument versions of min_by and max_by. @@ -2331,5 +2399,110 @@ TEST_F(MinMaxByNTest, peakMemory) { EXPECT_LT(maxByPeak, 190000); } +class MinMaxByTest : public AggregationTestBase { + public: + template + void testNanFloatValues() { + // Verify that NaN values are handeled correctly as being greater than + // infinity. + static const T kNaN = std::numeric_limits::quiet_NaN(); + static const T kSNaN = std::numeric_limits::signaling_NaN(); + static const T kInf = std::numeric_limits::infinity(); + + auto data = makeRowVector({ + // output column for min_by/max_by + makeFlatVector({1, 2, 3, 4, 5}), + // group by column + makeFlatVector({1, 1, 2, 2, 2}), + // regular ordering + makeFlatVector({2.0, kNaN, 1.1, kInf, -1.1}), + // with nulls + makeNullableFlatVector({2.0, 1.1, std::nullopt, kSNaN, -1.1}), + }); + + // Global aggregation. + { + auto expected = makeRowVector({ + makeFlatVector(std::vector({2})), + makeFlatVector(std::vector({4})), + makeFlatVector(std::vector({5})), + makeFlatVector(std::vector({5})), + }); + + testAggregations( + {data}, + {}, + { + "max_by(c0, c2)", + "max_by(c0, c3)", + "min_by(c0, c2)", + "min_by(c0, c3)", + }, + {expected}); + } + + // group-by aggregation. + { + auto expected = makeRowVector({ + makeFlatVector({1, 2}), // grouping key + makeFlatVector({2, 4}), + makeFlatVector({1, 4}), + makeFlatVector({1, 5}), + makeFlatVector({2, 5}), + }); + + testAggregations( + {data}, + {"c1"}, + { + "max_by(c0, c2)", + "max_by(c0, c3)", + "min_by(c0, c2)", + "min_by(c0, c3)", + }, + {expected}); + } + + // Test for float point values nested inside complex type. + data = makeRowVector({ + // output column for min_by/max_by + makeFlatVector({1, 2, 3, 4, 5, 6}), + // group by column + makeFlatVector({1, 1, 2, 2, 2, 2}), + makeRowVector({ + makeFlatVector({2, kNaN, 1, kInf, -1, kNaN}), + makeFlatVector({1, 1, 1, 2, 2, 2}), + }), + }); + + // Global aggregation. + { + auto expected = makeRowVector({ + makeFlatVector(std::vector({6})), + makeFlatVector(std::vector({5})), + }); + + testAggregations( + {data}, {}, {"max_by(c0, c2)", "min_by(c0, c2)"}, {expected}); + } + + // group-by aggregation. + { + auto expected = makeRowVector({ + makeFlatVector({1, 2}), // grouping key + makeFlatVector({2, 6}), + makeFlatVector({1, 5}), + }); + + testAggregations( + {data}, {"c1"}, {"max_by(c0, c2)", "min_by(c0, c2)"}, {expected}); + } + } +}; + +TEST_F(MinMaxByTest, nans) { + testNanFloatValues(); + testNanFloatValues(); +} } // namespace } // namespace facebook::velox::aggregate::test