Skip to content

Commit

Permalink
Fix handling of NaN for map variants (facebookincubator#9764)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#9764

Highlights of the this change:
- Introduces a utility class for floating point types that provide
comparator and hash functor to implement consistent behavior or NaNs
across the codebase. These can be passed to standard containers and
functions like std::map and std::sort, etc.
- This utility class will be used in upcoming changes to fixes for NaN
behavior across the codebase wherever folly or std containers or
functions are used that can accept such functors.
- Fix NaN handling of floating points for map variants which uses a
std::map to hold key and value variants. Without this, using NaN as
a key would result in an inconsistent state because NaN cannot be
compared with other values. This resulted in inconsistent comparison
of map vectors in tests where QueryAssertions.h uses a map variant
to compare such vectors. (See asserEqualResults())

Differential Revision: D57187815
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed May 10, 2024
1 parent f1b6ccf commit 59283f5
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 50 deletions.
2 changes: 1 addition & 1 deletion velox/functions/prestosql/ArithmeticImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <cmath>
#include <type_traits>
#include "folly/CPortability.h"
#include "velox/type/DoubleUtil.h"
#include "velox/type/FloatingPointUtil.h"

namespace facebook::velox::functions {

Expand Down
2 changes: 1 addition & 1 deletion velox/type/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ add_library(
velox_type
Conversions.cpp
DecimalUtil.cpp
DoubleUtil.cpp
Filter.cpp
FloatingPointUtil.cpp
HugeInt.cpp
StringView.cpp
StringView.h
Expand Down
29 changes: 0 additions & 29 deletions velox/type/DoubleUtil.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "velox/type/DoubleUtil.h"
#include "velox/type/FloatingPointUtil.h"
#include <array>

namespace facebook::velox {
Expand Down
76 changes: 76 additions & 0 deletions velox/type/FloatingPointUtil.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <array>
#include <cmath>
#include <vector>

namespace facebook::velox {

/// A utility class that holds custom comparator and hash functors for floating
/// point types. These are designed to ensure consistent NaN handling according
/// to the following rules:
/// - NaN == NaN returns true, even for NaNs with differing binary
/// representations.
/// - NaN is considered greater than infinity.
/// These can be passed to standard containers and functions like std::map and
/// std::sort, etc.
template <typename FLOAT>
class FloatingPointUtil {
public:
static_assert(
std::is_floating_point<FLOAT>::value,
"A floating point type is required.");

struct NaNAwareEquals {
bool operator()(const FLOAT& lhs, const FLOAT& rhs) const {
if (std::isnan(lhs) && std::isnan(rhs)) {
return true;
}
return lhs == rhs;
}
};

struct NaNAwareLessThan {
bool operator()(const FLOAT& lhs, const FLOAT& rhs) const {
if (!std::isnan(lhs) && std::isnan(rhs)) {
return true;
}
return lhs < rhs;
}
};

struct NaNAwareHash {
std::size_t operator()(const FLOAT& val) const noexcept {
static const std::size_t kNanHash =
std::hash<FLOAT>{}(std::numeric_limits<FLOAT>::quiet_NaN());
if (std::isnan(val)) {
return kNanHash;
}
return std::hash<FLOAT>{}(val);
}
};
};

/// A static class that holds helper functions for DOUBLE type.
class DoubleUtil {
public:
static const std::array<double, 309> kPowersOfTen;

}; // DoubleUtil
} // namespace facebook::velox
42 changes: 40 additions & 2 deletions velox/type/Variant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "folly/json.h"
#include "velox/common/encode/Base64.h"
#include "velox/type/DecimalUtil.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/type/HugeInt.h"

namespace facebook::velox {
Expand Down Expand Up @@ -670,6 +671,42 @@ variant variant::create(const folly::dynamic& variantobj) {
}
}

template <TypeKind KIND>
bool variant::lessThan(const variant& a, const variant& b) const {
if (a.isNull() && !b.isNull()) {
return true;
}
if (a.isNull() || b.isNull()) {
return false;
}
if constexpr (KIND == TypeKind::REAL) {
return FloatingPointUtil<float>::NaNAwareLessThan{}(
a.value<KIND>(), b.value<KIND>());
}
if constexpr (KIND == TypeKind::DOUBLE) {
return FloatingPointUtil<double>::NaNAwareLessThan{}(
a.value<KIND>(), b.value<KIND>());
}
return a.value<KIND>() < b.value<KIND>();
}

template <TypeKind KIND>
bool variant::equals(const variant& a, const variant& b) const {
if (a.isNull() || b.isNull()) {
return false;
}
if constexpr (KIND == TypeKind::REAL) {
return FloatingPointUtil<float>::NaNAwareEquals{}(
a.value<KIND>(), b.value<KIND>());
}
if constexpr (KIND == TypeKind::DOUBLE) {
return FloatingPointUtil<double>::NaNAwareEquals{}(
a.value<KIND>(), b.value<KIND>());
}
// todo(youknowjack): centralize equality semantics
return a.value<KIND>() == b.value<KIND>();
}

uint64_t variant::hash() const {
uint64_t hash = 0;
if (isNull()) {
Expand All @@ -690,9 +727,10 @@ uint64_t variant::hash() const {
case TypeKind::BOOLEAN:
return folly::Hash{}(value<TypeKind::BOOLEAN>());
case TypeKind::REAL:
return folly::Hash{}(value<TypeKind::REAL>());
return FloatingPointUtil<float>::NaNAwareHash{}(value<TypeKind::REAL>());
case TypeKind::DOUBLE:
return folly::Hash{}(value<TypeKind::DOUBLE>());
return FloatingPointUtil<double>::NaNAwareHash{}(
value<TypeKind::DOUBLE>());
case TypeKind::VARBINARY:
return folly::Hash{}(value<TypeKind::VARBINARY>());
case TypeKind::VARCHAR:
Expand Down
18 changes: 2 additions & 16 deletions velox/type/Variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,10 @@ class variant {
variant(TypeKind kind, void* ptr) : kind_{kind}, ptr_{ptr} {}

template <TypeKind KIND>
bool lessThan(const variant& a, const variant& b) const {
if (a.isNull() && !b.isNull()) {
return true;
}
if (a.isNull() || b.isNull()) {
return false;
}
return a.value<KIND>() < b.value<KIND>();
}
bool lessThan(const variant& a, const variant& b) const;

template <TypeKind KIND>
bool equals(const variant& a, const variant& b) const {
if (a.isNull() || b.isNull()) {
return false;
}
// todo(youknowjack): centralize equality semantics
return a.value<KIND>() == b.value<KIND>();
}
bool equals(const variant& a, const variant& b) const;

template <TypeKind KIND>
void typedDestroy() {
Expand Down
27 changes: 27 additions & 0 deletions velox/type/tests/VariantTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,33 @@ TEST(VariantTest, equalsWithEpsilonFloat) {
ASSERT_FALSE(variant(sum1).equalsWithEpsilon(variant(sum3)));
}

TEST(VariantTest, mapWithNaNKey) {
// Verify that map variants treat all NaN keys as equivalent and comparable
// (consider them the largest) with other values.
static const double KNan = std::numeric_limits<double>::quiet_NaN();
auto mapType = MAP(DOUBLE(), INTEGER());
{
// NaN added at the start of insertions.
std::map<variant, variant> mapVariant;
mapVariant.insert({variant(KNan), variant(1)});
mapVariant.insert({variant(1.2), variant(2)});
mapVariant.insert({variant(12.4), variant(3)});
EXPECT_EQ(
"[{\"key\":1.2,\"value\":2},{\"key\":12.4,\"value\":3},{\"key\":\"NaN\",\"value\":1}]",
variant::map(mapVariant).toJson(mapType));
}
{
// NaN added in the middle of insertions.
std::map<variant, variant> mapVariant;
mapVariant.insert({variant(1.2), variant(2)});
mapVariant.insert({variant(KNan), variant(1)});
mapVariant.insert({variant(12.4), variant(3)});
EXPECT_EQ(
"[{\"key\":1.2,\"value\":2},{\"key\":12.4,\"value\":3},{\"key\":\"NaN\",\"value\":1}]",
variant::map(mapVariant).toJson(mapType));
}
}

TEST(VariantTest, serialize) {
// Null values.
testSerDe(variant(TypeKind::BOOLEAN));
Expand Down

0 comments on commit 59283f5

Please sign in to comment.