Skip to content

Commit

Permalink
Fix NaN handling for multimap_agg (facebookincubator#9769)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#9769

Ensure multimap_agg treats all NaN keys as equal

Differential Revision: D57217450
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed May 10, 2024
1 parent 6a7906f commit 4c42254
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
18 changes: 18 additions & 0 deletions velox/functions/prestosql/aggregates/MultiMapAggAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "velox/exec/Strings.h"
#include "velox/functions/lib/aggregates/ValueList.h"
#include "velox/functions/prestosql/aggregates/AggregateNames.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/vector/FlatVector.h"

namespace facebook::velox::aggregate::prestosql {
Expand Down Expand Up @@ -232,6 +233,23 @@ struct MultiMapAccumulatorTypeTraits {
using AccumulatorType = MultiMapAccumulator<T>;
};

// Ensure Accumulator treats NaNs as equal.
template <>
struct MultiMapAccumulatorTypeTraits<float> {
using AccumulatorType = MultiMapAccumulator<
float,
util::floating_point::NaNAwareHash<float>,
util::floating_point::NaNAwareEquals<float>>;
};

template <>
struct MultiMapAccumulatorTypeTraits<double> {
using AccumulatorType = MultiMapAccumulator<
double,
util::floating_point::NaNAwareHash<double>,
util::floating_point::NaNAwareEquals<double>>;
};

template <>
struct MultiMapAccumulatorTypeTraits<ComplexType> {
using AccumulatorType = ComplexTypeMultiMapAccumulator;
Expand Down
36 changes: 36 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MultiMapAggTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"

Expand Down Expand Up @@ -266,5 +267,40 @@ TEST_F(MultiMapAggTest, arrayKeyGroupBy) {
{expected});
}

TEST_F(MultiMapAggTest, doubleKeyGlobal) {
// Verify that all NaN representations used as a map key are treated as equal
static const double KNan1 = std::nan("1");
static const double KNan2 = std::nan("2");
auto data = makeRowVector({
makeFlatVector<double>(
{KNan1, KNan2, 1.1, 0.2, 23.0, 2.0, 23.0, 2.0, 1.1, 0.2, 23.0, 2.0}),
makeNullableFlatVector<int64_t>(
{-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
});

auto expected = makeRowVector({
makeMapVector(
{
0,
},
makeFlatVector<double>({KNan1, 0.2, 1.1, 2.0, 23.0}),
makeArrayVector<int64_t>({
{-2, -1},
{1, 7},
{0, 6},
{3, 5, 9},
{2, 4, 8},
})),
});

testAggregations(
{data},
{},
{"multimap_agg(c0, c1)"},
// Sort the result arrays to ensure deterministic results.
{"transform_values(a0, (k, v) -> array_sort(v))"},
{expected});
}

} // namespace
} // namespace facebook::velox::aggregate::prestosql

0 comments on commit 4c42254

Please sign in to comment.