Skip to content

Commit

Permalink
Fix handling of NaN for map variants (facebookincubator#9764)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#9764

Highlights of this change:
- Introduces utility functors for floating point types that provide
comparator and hash functors to implement consistent behavior for NaNs
across the codebase. These can be passed to standard containers and
functions like std::map and std::sort, etc.
- These utility functors will be used in upcoming changes that fix NaN
behavior across the codebase wherever folly or std containers or
functions are used that can accept such functors.
- Fixes NaN handling of floating point keys for map variants, which use a
std::map to hold key and value variants. Without this, using NaN as
a key would result in an inconsistent state because NaN cannot be
compared with other values. This resulted in inconsistent comparison
of map vectors in tests where QueryAssertions.h uses a map variant
to compare such vectors. (See assertEqualResults())

Differential Revision: D57187815
  • Loading branch information
bikramSingh91 authored and facebook-github-bot committed May 10, 2024
1 parent 8184113 commit 6a7906f
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 65 deletions.
2 changes: 1 addition & 1 deletion velox/functions/prestosql/ArithmeticImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include <cmath>
#include <type_traits>
#include "folly/CPortability.h"
#include "velox/type/DoubleUtil.h"
#include "velox/type/FloatingPointUtil.h"

namespace facebook::velox::functions {

Expand Down
2 changes: 1 addition & 1 deletion velox/type/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ add_library(
velox_type
Conversions.cpp
DecimalUtil.cpp
DoubleUtil.cpp
Filter.cpp
FloatingPointUtil.cpp
HugeInt.cpp
StringView.cpp
StringView.h
Expand Down
29 changes: 0 additions & 29 deletions velox/type/DoubleUtil.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
* limitations under the License.
*/

#include "velox/type/DoubleUtil.h"
#include "velox/type/FloatingPointUtil.h"
#include <array>

namespace facebook::velox {
Expand Down
82 changes: 82 additions & 0 deletions velox/type/FloatingPointUtil.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <array>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
#include <vector>

namespace facebook::velox {

/// Custom comparator and hash functors for floating point types. These are
/// designed to ensure consistent NaN handling according to the following rules:
/// - NaN == NaN returns true, even for NaNs with differing binary
/// representations.
/// - NaN is considered greater than infinity.
/// These can be passed to standard containers and functions like std::map and
/// std::sort, etc.
namespace util::floating_point {
/// Equality predicate for floating point values in which any two NaNs are
/// considered equal to each other; every other pair is compared with the
/// built-in operator==.
template <typename FLOAT>
struct NaNAwareEquals {
  bool operator()(const FLOAT& lhs, const FLOAT& rhs) const {
    const bool bothNaN = std::isnan(lhs) && std::isnan(rhs);
    return bothNaN || lhs == rhs;
  }
};

/// Strict "less than" for floating point values in which NaN is ordered after
/// every non-NaN value (including +infinity) and is not less than another NaN.
template <typename FLOAT>
struct NaNAwareLessThan {
  bool operator()(const FLOAT& lhs, const FLOAT& rhs) const {
    if (std::isnan(rhs)) {
      // A NaN right-hand side is the maximum: only a non-NaN lhs precedes it.
      return !std::isnan(lhs);
    }
    // rhs is a regular number; a NaN lhs makes this comparison false.
    return lhs < rhs;
  }
};

/// Strict "greater than" for floating point values in which NaN is greater
/// than every non-NaN value (including +infinity) and is not greater than
/// another NaN.
template <typename FLOAT>
struct NaNAwareGreaterThan {
  bool operator()(const FLOAT& lhs, const FLOAT& rhs) const {
    if (std::isnan(lhs)) {
      // A NaN left-hand side is the maximum: it exceeds anything but NaN.
      return !std::isnan(rhs);
    }
    // lhs is a regular number; a NaN rhs makes this comparison false.
    return lhs > rhs;
  }
};

/// Hash functor for floating point values that maps every NaN (whatever its
/// sign or payload bits) to the same hash, namely the std::hash of the
/// canonical quiet NaN. Non-NaN values hash exactly as std::hash<FLOAT> does.
/// Intended to be paired with an equality predicate that treats all NaNs as
/// equal.
template <typename FLOAT>
struct NaNAwareHash {
  std::size_t operator()(const FLOAT& val) const noexcept {
    // Canonicalize NaN before hashing instead of caching the NaN hash in a
    // function-local static: the result is the same, but no static-init guard
    // has to be checked on every call.
    const FLOAT canonical =
        std::isnan(val) ? std::numeric_limits<FLOAT>::quiet_NaN() : val;
    return std::hash<FLOAT>{}(canonical);
  }
};
} // namespace util::floating_point

/// A static class that holds helpers for the DOUBLE type. In this header it
/// only declares a lookup table; the table is not defined inline, so its
/// definition must be provided out of line.
class DoubleUtil {
public:
/// Table of 309 double values. NOTE(review): judging by the name, entry i is
/// presumably 10^i for i in [0, 308] - confirm against the definition.
static const std::array<double, 309> kPowersOfTen;

}; // DoubleUtil
} // namespace facebook::velox
57 changes: 55 additions & 2 deletions velox/type/Variant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "folly/json.h"
#include "velox/common/encode/Base64.h"
#include "velox/type/DecimalUtil.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/type/HugeInt.h"

namespace facebook::velox {
Expand Down Expand Up @@ -670,7 +671,59 @@ variant variant::create(const folly::dynamic& variantobj) {
}
}

template <TypeKind KIND>
bool variant::lessThan(const variant& a, const variant& b) const {
using namespace facebook::velox::util::floating_point;
if (a.isNull() && !b.isNull()) {
return true;
}
if (a.isNull() || b.isNull()) {
return false;
}
if constexpr (KIND == TypeKind::REAL) {
return NaNAwareLessThan<float>{}(a.value<KIND>(), b.value<KIND>());
}
if constexpr (KIND == TypeKind::DOUBLE) {
return NaNAwareLessThan<double>{}(a.value<KIND>(), b.value<KIND>());
}
return a.value<KIND>() < b.value<KIND>();
}

/// Orders two variants. Variants of the same kind are compared by the typed
/// lessThan<KIND>; variants of different kinds are ordered by comparing their
/// kinds, with the operands in the order `other.kind_ < this->kind_`.
bool variant::operator<(const variant& other) const {
  const bool sameKind = (this->kind_ == other.kind_);
  if (sameKind) {
    return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(lessThan, kind_, *this, other);
  }
  return other.kind_ < this->kind_;
}

template <TypeKind KIND>
bool variant::equals(const variant& a, const variant& b) const {
using namespace facebook::velox::util::floating_point;
if (a.isNull() || b.isNull()) {
return false;
}
if constexpr (KIND == TypeKind::REAL) {
return NaNAwareEquals<float>{}(a.value<KIND>(), b.value<KIND>());
}
if constexpr (KIND == TypeKind::DOUBLE) {
return NaNAwareEquals<double>{}(a.value<KIND>(), b.value<KIND>());
}
// todo(youknowjack): centralize equality semantics
return a.value<KIND>() == b.value<KIND>();
}

/// Compares this variant with `other`. Variants of different kinds are never
/// equal. If `other` is null, the result is whether this variant is null too;
/// otherwise the comparison is delegated to the typed equals<KIND>.
bool variant::equals(const variant& other) const {
  const bool sameKind = (this->kind_ == other.kind_);
  if (!sameKind) {
    return false;
  }
  if (!other.isNull()) {
    return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(equals, kind_, *this, other);
  }
  return this->isNull();
}

uint64_t variant::hash() const {
using namespace facebook::velox::util::floating_point;
uint64_t hash = 0;
if (isNull()) {
return folly::Hash{}(static_cast<int32_t>(kind_));
Expand All @@ -690,9 +743,9 @@ uint64_t variant::hash() const {
case TypeKind::BOOLEAN:
return folly::Hash{}(value<TypeKind::BOOLEAN>());
case TypeKind::REAL:
return folly::Hash{}(value<TypeKind::REAL>());
return NaNAwareHash<float>{}(value<TypeKind::REAL>());
case TypeKind::DOUBLE:
return folly::Hash{}(value<TypeKind::DOUBLE>());
return NaNAwareHash<double>{}(value<TypeKind::DOUBLE>());
case TypeKind::VARBINARY:
return folly::Hash{}(value<TypeKind::VARBINARY>());
case TypeKind::VARCHAR:
Expand Down
35 changes: 4 additions & 31 deletions velox/type/Variant.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,24 +130,10 @@ class variant {
variant(TypeKind kind, void* ptr) : kind_{kind}, ptr_{ptr} {}

template <TypeKind KIND>
bool lessThan(const variant& a, const variant& b) const {
if (a.isNull() && !b.isNull()) {
return true;
}
if (a.isNull() || b.isNull()) {
return false;
}
return a.value<KIND>() < b.value<KIND>();
}
bool lessThan(const variant& a, const variant& b) const;

template <TypeKind KIND>
bool equals(const variant& a, const variant& b) const {
if (a.isNull() || b.isNull()) {
return false;
}
// todo(youknowjack): centralize equality semantics
return a.value<KIND>() == b.value<KIND>();
}
bool equals(const variant& a, const variant& b) const;

template <TypeKind KIND>
void typedDestroy() {
Expand Down Expand Up @@ -367,22 +353,9 @@ class variant {
return *this;
}

bool operator<(const variant& other) const {
if (other.kind_ != this->kind_) {
return other.kind_ < this->kind_;
}
return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(lessThan, kind_, *this, other);
}
bool operator<(const variant& other) const;

bool equals(const variant& other) const {
if (other.kind_ != this->kind_) {
return false;
}
if (other.isNull()) {
return this->isNull();
}
return VELOX_DYNAMIC_TYPE_DISPATCH_ALL(equals, kind_, *this, other);
}
bool equals(const variant& other) const;

bool equalsWithNullEqualsNull(const variant& other) const {
if (other.kind_ != this->kind_) {
Expand Down
27 changes: 27 additions & 0 deletions velox/type/tests/VariantTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,33 @@ TEST(VariantTest, equalsWithEpsilonFloat) {
ASSERT_FALSE(variant(sum1).equalsWithEpsilon(variant(sum3)));
}

TEST(VariantTest, mapWithNaNKey) {
  // Map variants must order a NaN key after all other keys, and the resulting
  // JSON must not depend on when the NaN key was inserted.
  static const double kNaN = std::numeric_limits<double>::quiet_NaN();
  const char* const kExpectedJson =
      "[{\"key\":1.2,\"value\":2},{\"key\":12.4,\"value\":3},{\"key\":\"NaN\",\"value\":1}]";
  auto mapType = MAP(DOUBLE(), INTEGER());

  // Case 1: the NaN key is the first insertion.
  {
    std::map<variant, variant> nanFirst;
    nanFirst.emplace(variant(kNaN), variant(1));
    nanFirst.emplace(variant(1.2), variant(2));
    nanFirst.emplace(variant(12.4), variant(3));
    EXPECT_EQ(kExpectedJson, variant::map(nanFirst).toJson(mapType));
  }

  // Case 2: the NaN key is inserted between two regular keys.
  {
    std::map<variant, variant> nanMiddle;
    nanMiddle.emplace(variant(1.2), variant(2));
    nanMiddle.emplace(variant(kNaN), variant(1));
    nanMiddle.emplace(variant(12.4), variant(3));
    EXPECT_EQ(kExpectedJson, variant::map(nanMiddle).toJson(mapType));
  }
}

TEST(VariantTest, serialize) {
// Null values.
testSerDe(variant(TypeKind::BOOLEAN));
Expand Down

0 comments on commit 6a7906f

Please sign in to comment.