Skip to content

Commit

Permalink
Fix NaN handling for min/max aggregates pushed down to scan
Browse files Browse the repository at this point in the history
Summary: This fixes min/max aggregates to handle NaN values correctly when they are pushed down to the scan operator. Specifically, the change ensures that NaN values are considered greater than infinity.

Reviewed By: zacw7

Differential Revision: D60297934
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed Jul 26, 2024
1 parent 237ff41 commit 10e80fb
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 45 deletions.
18 changes: 13 additions & 5 deletions velox/exec/AggregationHook.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#include "velox/common/base/CheckedArithmetic.h"
#include "velox/common/base/Range.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/vector/LazyVector.h"

namespace facebook::velox::aggregate {
Expand Down Expand Up @@ -231,11 +232,18 @@ class MinMaxHook final : public AggregationHook {

void addValue(vector_size_t row, const void* value) override {
auto group = findGroup(row);
if (clearNull(group) ||
(*reinterpret_cast<T*>(group + offset_) >
*reinterpret_cast<const T*>(value)) == isMin) {
*reinterpret_cast<T*>(group + offset_) =
*reinterpret_cast<const T*>(value);
T* currPtr = reinterpret_cast<T*>(group + offset_);
const T* valPtr = reinterpret_cast<const T*>(value);
if constexpr (std::is_floating_point_v<T>) {
static const auto isGreater =
util::floating_point::NaNAwareGreaterThan<T>{};
if (clearNull(group) || isGreater(*currPtr, *valPtr) == isMin) {
*currPtr = *valPtr;
}
} else {
if (clearNull(group) || (*currPtr > *valPtr) == isMin) {
*currPtr = *valPtr;
}
}
}
};
Expand Down
130 changes: 90 additions & 40 deletions velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@ class MinMaxTest : public functions::aggregate::test::AggregationTestBase {
// including, INF, -INF, NaN. This validates that the groups have initial
// value set correctly, (-INF for max() and NaN for min()) and NaN is
// considered greater than INF. Also tests for when floating points are
// nested inside complex types.
// nested inside complex types. Finally, this also tests for when
// aggregation is pushed down to the scan operator which can only happen if
// the column is a primitive type and not used anywhere except a single
// aggregate.
static const T kNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSNaN = std::numeric_limits<T>::signaling_NaN();
static const T kInf = std::numeric_limits<T>::infinity();
Expand All @@ -134,63 +137,110 @@ class MinMaxTest : public functions::aggregate::test::AggregationTestBase {

// Global aggregation.
{
auto expected = makeRowVector(
{makeFlatVector<T>(std::vector<T>({-1.1})),
makeFlatVector<T>(std::vector<T>({kNaN})),
makeFlatVector<T>(std::vector<T>({-1.1})),
makeFlatVector<T>(std::vector<T>({kNaN})),
makeFlatVector<T>(std::vector<T>({kNaN})),
makeFlatVector<T>(std::vector<T>({kNaN})),
makeFlatVector<T>(std::vector<T>({kInf})),
makeFlatVector<T>(std::vector<T>({kInf})),
makeFlatVector<T>(std::vector<T>({-kInf})),
makeFlatVector<T>(std::vector<T>({-kInf}))});
// Verify max pushed down to scan operator.
std::vector<VectorPtr> expectedMaxValues = {
makeFlatVector<T>(std::vector<T>({kNaN})),
makeFlatVector<T>(std::vector<T>({kNaN})),
makeFlatVector<T>(std::vector<T>({kNaN})),
makeFlatVector<T>(std::vector<T>({kInf})),
makeFlatVector<T>(std::vector<T>({-kInf}))};

testAggregations(
{data},
{},
{"min(c0)",
"max(c0)",
"min(c1)",
{"max(c0)", "max(c1)", "max(c2)", "max(c3)", "max(c4)"},
{makeRowVector(expectedMaxValues)});

// Verify min pushed down to scan operator.
std::vector<VectorPtr> expectedMinValues = {
makeFlatVector<T>(std::vector<T>({-1.1})),
makeFlatVector<T>(std::vector<T>({-1.1})),
makeFlatVector<T>(std::vector<T>({kNaN})),
makeFlatVector<T>(std::vector<T>({kInf})),
makeFlatVector<T>(std::vector<T>({-kInf})),
};
testAggregations(
{data},
{},
{"min(c0)", "min(c1)", "min(c2)", "min(c3)", "min(c4)"},
{makeRowVector(expectedMinValues)});

// Verify max and min evaluated in aggregation operator.
std::vector<VectorPtr> allExpectedValues = expectedMaxValues;
allExpectedValues.insert(
allExpectedValues.end(),
expectedMinValues.begin(),
expectedMinValues.end());

testAggregations(
{data},
{},
{"max(c0)",
"max(c1)",
"min(c2)",
"max(c2)",
"min(c3)",
"max(c3)",
"min(c4)",
"max(c4)"},
{expected});
"max(c4)",
"min(c0)",
"min(c1)",
"min(c2)",
"min(c3)",
"min(c4)"},
{makeRowVector(allExpectedValues)});
}

// group-by aggregation.
{
auto expected = makeRowVector(
{makeFlatVector<int32_t>({1, 2}),
makeFlatVector<T>({1.1, -1.1}),
makeFlatVector<T>({kNaN, kInf}),
makeFlatVector<T>({2.0, -1.1}),
makeFlatVector<T>({kNaN, 1.1}),
makeFlatVector<T>({kNaN, kNaN}),
makeFlatVector<T>({kNaN, kNaN}),
makeFlatVector<T>({kInf, kInf}),
makeFlatVector<T>({kInf, kInf}),
makeFlatVector<T>({-kInf, -kInf}),
makeFlatVector<T>({-kInf, -kInf})});
// Verify max pushed down to scan operator.
std::vector<VectorPtr> expectedMaxValues = {
makeFlatVector<int32_t>({1, 2}), // grouping key
makeFlatVector<T>({kNaN, kInf}),
makeFlatVector<T>({kNaN, 1.1}),
makeFlatVector<T>({kNaN, kNaN}),
makeFlatVector<T>({kInf, kInf}),
makeFlatVector<T>({-kInf, -kInf})};

testAggregations(
{data},
{"c5"},
{"min(c0)",
"max(c0)",
"min(c1)",
{"max(c0)", "max(c1)", "max(c2)", "max(c3)", "max(c4)"},
{makeRowVector(expectedMaxValues)});

// Verify min pushed down to scan operator.
std::vector<VectorPtr> expectedMinValues = {
makeFlatVector<int32_t>({1, 2}), // grouping key
makeFlatVector<T>({1.1, -1.1}),
makeFlatVector<T>({2.0, -1.1}),
makeFlatVector<T>({kNaN, kNaN}),
makeFlatVector<T>({kInf, kInf}),
makeFlatVector<T>({-kInf, -kInf})};

testAggregations(
{data},
{"c5"},
{"min(c0)", "min(c1)", "min(c2)", "min(c3)", "min(c4)"},
{makeRowVector(expectedMinValues)});

// Verify max and min evaluated in aggregation operator.
std::vector<VectorPtr> allExpectedValues = expectedMaxValues;
allExpectedValues.insert(
allExpectedValues.end(),
expectedMinValues.begin() + 1, // skip the grouping key column
expectedMinValues.end());

testAggregations(
{data},
{"c5"},
{"max(c0)",
"max(c1)",
"min(c2)",
"max(c2)",
"min(c3)",
"max(c3)",
"min(c4)",
"max(c4)"},
{expected});
"max(c4)",
"min(c0)",
"min(c1)",
"min(c2)",
"min(c3)",
"min(c4)"},
{makeRowVector(allExpectedValues)});
}

// Test for float point values nested inside complex type.
Expand Down

0 comments on commit 10e80fb

Please sign in to comment.