Skip to content

Commit

Permalink
Fix NaN handling for multimap_agg
Browse files Browse the repository at this point in the history
Summary:
Highlights of the this change:
- Ensure multimap_agg treats all NaN binary representations are equal
- Introduces a utility class for floating point types that provide
comparator and hash functor to implement consistent behavior or NaNs
across the codebase. These can be passed to standard containers and
functions like std::map and std::sort, etc.
- Fix NaN handling of floating points for map variants which uses a
std::map to hold key and value variants. Without this, using NaN as
a key would result in an inconsistent state because once NaN is added
as the first value, it cannot compare NaN with anything else or if NaN
is added after some initial regular values, then it does not accept
NaN for the same reason.

Differential Revision: D57187815
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed May 9, 2024
1 parent 49c3ebb commit 34e7295
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 50 deletions.
2 changes: 1 addition & 1 deletion velox/functions/prestosql/ArithmeticImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <cmath>
#include <type_traits>
#include "folly/CPortability.h"
#include "velox/type/DoubleUtil.h"
#include "velox/type/FloatingPointUtil.h"

namespace facebook::velox::functions {

Expand Down
18 changes: 18 additions & 0 deletions velox/functions/prestosql/aggregates/MultiMapAggAggregate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "velox/exec/Strings.h"
#include "velox/functions/lib/aggregates/ValueList.h"
#include "velox/functions/prestosql/aggregates/AggregateNames.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/vector/FlatVector.h"

namespace facebook::velox::aggregate::prestosql {
Expand Down Expand Up @@ -232,6 +233,23 @@ struct MultiMapAccumulatorTypeTraits {
using AccumulatorType = MultiMapAccumulator<T>;
};

// Ensure Accumulator treats NaNs as equal.
template <>
struct MultiMapAccumulatorTypeTraits<float> {
using AccumulatorType = MultiMapAccumulator<
float,
FloatingPointUtil<float>::NaNAwareHash,
FloatingPointUtil<float>::NaNAwareEquals>;
};

template <>
struct MultiMapAccumulatorTypeTraits<double> {
using AccumulatorType = MultiMapAccumulator<
double,
FloatingPointUtil<double>::NaNAwareHash,
FloatingPointUtil<double>::NaNAwareEquals>;
};

template <>
struct MultiMapAccumulatorTypeTraits<ComplexType> {
using AccumulatorType = ComplexTypeMultiMapAccumulator;
Expand Down
36 changes: 36 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MultiMapAggTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cmath>
#include "velox/exec/tests/utils/AssertQueryBuilder.h"
#include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h"

Expand Down Expand Up @@ -266,5 +267,40 @@ TEST_F(MultiMapAggTest, arrayKeyGroupBy) {
{expected});
}

TEST_F(MultiMapAggTest, doubleKeyGlobal) {
// Verify that all NaN representations used as a map key are treated as equal
static const double KNan1 = std::nan("1");
static const double KNan2 = std::nan("2");
auto data = makeRowVector({
makeFlatVector<double>(
{KNan1, KNan2, 1.1, 0.2, 23.0, 2.0, 23.0, 2.0, 1.1, 0.2, 23.0, 2.0}),
makeNullableFlatVector<int64_t>(
{-2, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}),
});

auto expected = makeRowVector({
makeMapVector(
{
0,
},
makeFlatVector<double>({KNan1, 0.2, 1.1, 2.0, 23.0}),
makeArrayVector<int64_t>({
{-2, -1},
{1, 7},
{0, 6},
{3, 5, 9},
{2, 4, 8},
})),
});

testAggregations(
{data},
{},
{"multimap_agg(c0, c1)"},
// Sort the result arrays to ensure deterministic results.
{"transform_values(a0, (k, v) -> array_sort(v))"},
{expected});
}

} // namespace
} // namespace facebook::velox::aggregate::prestosql
2 changes: 1 addition & 1 deletion velox/type/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ add_library(
velox_type
Conversions.cpp
DecimalUtil.cpp
DoubleUtil.cpp
Filter.cpp
FloatingPointUtil.cpp
HugeInt.cpp
StringView.cpp
StringView.h
Expand Down
29 changes: 0 additions & 29 deletions velox/type/DoubleUtil.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "velox/type/DoubleUtil.h"
#include "velox/type/FloatingPointUtil.h"
#include <array>

namespace facebook::velox {
Expand Down
76 changes: 76 additions & 0 deletions velox/type/FloatingPointUtil.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <array>
#include <cmath>
#include <vector>

namespace facebook::velox {

/// A utility class that holds custom comparator and hash functors for floating
/// point types. These are designed to ensure consistent NaN handling according
/// to the following rules:
/// - NaN == NaN returns true, even for NaNs with differing binary
/// representations.
/// - NaN is considered greater than infinity.
/// These can be passed to standard containers and functions like std::map and
/// std::sort, etc.
template <typename FLOAT>
class FloatingPointUtil {
public:
static_assert(
std::is_floating_point<FLOAT>::value,
"A floating point type is required.");

struct NaNAwareEquals {
bool operator()(const FLOAT& lhs, const FLOAT& rhs) const {
if (std::isnan(lhs) && std::isnan(rhs)) {
return true;
}
return lhs == rhs;
}
};

struct NaNAwareLessThan {
bool operator()(const FLOAT& lhs, const FLOAT& rhs) const {
if (!std::isnan(lhs) && std::isnan(rhs)) {
return true;
}
return lhs < rhs;
}
};

struct NaNAwareHash {
std::size_t operator()(const FLOAT& val) const noexcept {
static const std::size_t kNanHash =
std::hash<FLOAT>{}(std::numeric_limits<FLOAT>::quiet_NaN());
if (std::isnan(val)) {
return kNanHash;
}
return std::hash<FLOAT>{}(val);
}
};
};

/// A static class that holds helper functions for DOUBLE type.
class DoubleUtil {
public:
static const std::array<double, 309> kPowersOfTen;

}; // DoubleUtil
} // namespace facebook::velox
42 changes: 40 additions & 2 deletions velox/type/Variant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "folly/json.h"
#include "velox/common/encode/Base64.h"
#include "velox/type/DecimalUtil.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/type/HugeInt.h"

namespace facebook::velox {
Expand Down Expand Up @@ -670,6 +671,42 @@ variant variant::create(const folly::dynamic& variantobj) {
}
}

template <TypeKind KIND>
bool variant::lessThan(const variant& a, const variant& b) const {
if (a.isNull() && !b.isNull()) {
return true;
}
if (a.isNull() || b.isNull()) {
return false;
}
if constexpr (KIND == TypeKind::REAL) {
return FloatingPointUtil<float>::NaNAwareLessThan{}(
a.value<KIND>(), b.value<KIND>());
}
if constexpr (KIND == TypeKind::DOUBLE) {
return FloatingPointUtil<double>::NaNAwareLessThan{}(
a.value<KIND>(), b.value<KIND>());
}
return a.value<KIND>() < b.value<KIND>();
}

template <TypeKind KIND>
bool variant::equals(const variant& a, const variant& b) const {
if (a.isNull() || b.isNull()) {
return false;
}
if constexpr (KIND == TypeKind::REAL) {
return FloatingPointUtil<float>::NaNAwareEquals{}(
a.value<KIND>(), b.value<KIND>());
}
if constexpr (KIND == TypeKind::DOUBLE) {
return FloatingPointUtil<double>::NaNAwareEquals{}(
a.value<KIND>(), b.value<KIND>());
}
// todo(youknowjack): centralize equality semantics
return a.value<KIND>() == b.value<KIND>();
}

uint64_t variant::hash() const {
uint64_t hash = 0;
if (isNull()) {
Expand All @@ -690,9 +727,10 @@ uint64_t variant::hash() const {
case TypeKind::BOOLEAN:
return folly::Hash{}(value<TypeKind::BOOLEAN>());
case TypeKind::REAL:
return folly::Hash{}(value<TypeKind::REAL>());
return FloatingPointUtil<float>::NaNAwareHash{}(value<TypeKind::REAL>());
case TypeKind::DOUBLE:
return folly::Hash{}(value<TypeKind::DOUBLE>());
return FloatingPointUtil<double>::NaNAwareHash{}(
value<TypeKind::DOUBLE>());
case TypeKind::VARBINARY:
return folly::Hash{}(value<TypeKind::VARBINARY>());
case TypeKind::VARCHAR:
Expand Down
18 changes: 2 additions & 16 deletions velox/type/Variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,10 @@ class variant {
variant(TypeKind kind, void* ptr) : kind_{kind}, ptr_{ptr} {}

template <TypeKind KIND>
bool lessThan(const variant& a, const variant& b) const {
if (a.isNull() && !b.isNull()) {
return true;
}
if (a.isNull() || b.isNull()) {
return false;
}
return a.value<KIND>() < b.value<KIND>();
}
bool lessThan(const variant& a, const variant& b) const;

template <TypeKind KIND>
bool equals(const variant& a, const variant& b) const {
if (a.isNull() || b.isNull()) {
return false;
}
// todo(youknowjack): centralize equality semantics
return a.value<KIND>() == b.value<KIND>();
}
bool equals(const variant& a, const variant& b) const;

template <TypeKind KIND>
void typedDestroy() {
Expand Down
27 changes: 27 additions & 0 deletions velox/type/tests/VariantTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,33 @@ TEST(VariantTest, equalsWithEpsilonFloat) {
ASSERT_FALSE(variant(sum1).equalsWithEpsilon(variant(sum3)));
}

TEST(VariantTest, mapWithNaNKey) {
// Verify that map variants treat all NaN keys as equivalent and comparable
// (consider them the largest) with other values.
static const double KNan = std::numeric_limits<double>::quiet_NaN();
auto mapType = MAP(DOUBLE(), INTEGER());
{
// NaN added at the start of insertions.
std::map<variant, variant> mapVariant;
mapVariant.insert({variant(KNan), variant(1)});
mapVariant.insert({variant(1.2), variant(2)});
mapVariant.insert({variant(12.4), variant(3)});
EXPECT_EQ(
"[{\"key\":1.2,\"value\":2},{\"key\":12.4,\"value\":3},{\"key\":\"NaN\",\"value\":1}]",
variant::map(mapVariant).toJson(mapType));
}
{
// NaN added in the middle of insertions.
std::map<variant, variant> mapVariant;
mapVariant.insert({variant(1.2), variant(2)});
mapVariant.insert({variant(KNan), variant(1)});
mapVariant.insert({variant(12.4), variant(3)});
EXPECT_EQ(
"[{\"key\":1.2,\"value\":2},{\"key\":12.4,\"value\":3},{\"key\":\"NaN\",\"value\":1}]",
variant::map(mapVariant).toJson(mapType));
}
}

TEST(VariantTest, serialize) {
// Null values.
testSerDe(variant(TypeKind::BOOLEAN));
Expand Down

0 comments on commit 34e7295

Please sign in to comment.