NOTE(review): this patch was rebuilt from a copy whose line breaks had been
collapsed and whose `<...>` spans had been dropped. Added (+) lines were
restored from context. The following could not be confirmed and must be checked
against the repository before applying: the angle-bracket include names in
unchanged context lines (ArithmeticImpl.h, FloatingPointUtil.cpp), the
`#include <array>` in the deleted DoubleUtil.h, and the
`std::array<double, 309>` element type/extent of kPowersOfTen.
Intentional changes relative to the recovered text: FloatingPointUtil.h lists
every standard header it uses, the functors and the two variant comparison
templates carry short contract comments, and the test constant is named kNan.

diff --git a/velox/functions/prestosql/ArithmeticImpl.h b/velox/functions/prestosql/ArithmeticImpl.h
index 11fe000d7265..6ea7d532aae7 100644
--- a/velox/functions/prestosql/ArithmeticImpl.h
+++ b/velox/functions/prestosql/ArithmeticImpl.h
@@ -19,7 +19,7 @@
 #include <cmath>
 #include <type_traits>
 #include "folly/CPortability.h"
-#include "velox/type/DoubleUtil.h"
+#include "velox/type/FloatingPointUtil.h"
 
 namespace facebook::velox::functions {
 
diff --git a/velox/type/CMakeLists.txt b/velox/type/CMakeLists.txt
index 326013c68c21..640a1458e187 100644
--- a/velox/type/CMakeLists.txt
+++ b/velox/type/CMakeLists.txt
@@ -23,8 +23,8 @@ add_library(
   velox_type
   Conversions.cpp
   DecimalUtil.cpp
-  DoubleUtil.cpp
   Filter.cpp
+  FloatingPointUtil.cpp
   HugeInt.cpp
   StringView.cpp
   StringView.h
diff --git a/velox/type/DoubleUtil.h b/velox/type/DoubleUtil.h
deleted file mode 100644
index 755ee2996e86..000000000000
--- a/velox/type/DoubleUtil.h
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Copyright (c) Facebook, Inc. and its affiliates.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <array>
-
-namespace facebook::velox {
-
-/// A static class that holds helper functions for DOUBLE type.
-class DoubleUtil {
- public:
-  static const std::array<double, 309> kPowersOfTen;
-
-}; // DoubleUtil
-} // namespace facebook::velox
diff --git a/velox/type/DoubleUtil.cpp b/velox/type/FloatingPointUtil.cpp
similarity index 98%
rename from velox/type/DoubleUtil.cpp
rename to velox/type/FloatingPointUtil.cpp
index 7354e7369700..2e4fc9019b14 100644
--- a/velox/type/DoubleUtil.cpp
+++ b/velox/type/FloatingPointUtil.cpp
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "velox/type/DoubleUtil.h"
+#include "velox/type/FloatingPointUtil.h"
 #include <array>
 
 namespace facebook::velox {
diff --git a/velox/type/FloatingPointUtil.h b/velox/type/FloatingPointUtil.h
new file mode 100644
index 000000000000..39061249fe91
--- /dev/null
+++ b/velox/type/FloatingPointUtil.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <array>
+#include <cmath>
+#include <functional>
+#include <limits>
+#include <type_traits>
+
+namespace facebook::velox {
+
+/// A utility class that holds custom comparator and hash functors for floating
+/// point types. These are designed to ensure consistent NaN handling according
+/// to the following rules:
+/// - NaN == NaN returns true, even for NaNs with differing binary
+///   representations.
+/// - NaN is considered greater than infinity.
+/// These can be passed to standard containers and functions like std::map and
+/// std::sort, etc.
+template <typename FLOAT>
+class FloatingPointUtil {
+ public:
+  static_assert(
+      std::is_floating_point<FLOAT>::value,
+      "A floating point type is required.");
+
+  /// Equality predicate: true if both operands are NaN, else lhs == rhs.
+  struct NaNAwareEquals {
+    bool operator()(const FLOAT& lhs, const FLOAT& rhs) const {
+      if (std::isnan(lhs) && std::isnan(rhs)) {
+        return true;
+      }
+      return lhs == rhs;
+    }
+  };
+
+  /// Ordering predicate: every non-NaN value is less than NaN, and NaN is
+  /// never less than anything. Otherwise the result is lhs < rhs.
+  struct NaNAwareLessThan {
+    bool operator()(const FLOAT& lhs, const FLOAT& rhs) const {
+      if (!std::isnan(lhs) && std::isnan(rhs)) {
+        return true;
+      }
+      return lhs < rhs;
+    }
+  };
+
+  /// Hash functor: every NaN maps to the hash of the quiet NaN so that it
+  /// agrees with NaNAwareEquals. All other values use std::hash.
+  struct NaNAwareHash {
+    std::size_t operator()(const FLOAT& val) const noexcept {
+      static const std::size_t kNanHash =
+          std::hash<FLOAT>{}(std::numeric_limits<FLOAT>::quiet_NaN());
+      if (std::isnan(val)) {
+        return kNanHash;
+      }
+      return std::hash<FLOAT>{}(val);
+    }
+  };
+};
+
+/// A static class that holds helper functions for DOUBLE type.
+class DoubleUtil {
+ public:
+  static const std::array<double, 309> kPowersOfTen;
+
+}; // DoubleUtil
+} // namespace facebook::velox
diff --git a/velox/type/Variant.cpp b/velox/type/Variant.cpp
index 1f15cff4c6fb..a7756b276ac3 100644
--- a/velox/type/Variant.cpp
+++ b/velox/type/Variant.cpp
@@ -19,6 +19,7 @@
 #include "folly/json.h"
 #include "velox/common/encode/Base64.h"
 #include "velox/type/DecimalUtil.h"
+#include "velox/type/FloatingPointUtil.h"
 #include "velox/type/HugeInt.h"
 
 namespace facebook::velox {
@@ -670,6 +671,46 @@ variant variant::create(const folly::dynamic& variantobj) {
   }
 }
 
+// Orders two variants of kind KIND: a null sorts before a non-null, and REAL
+// and DOUBLE values use the NaN-aware comparator (NaN after all other values).
+template <TypeKind KIND>
+bool variant::lessThan(const variant& a, const variant& b) const {
+  if (a.isNull() && !b.isNull()) {
+    return true;
+  }
+  if (a.isNull() || b.isNull()) {
+    return false;
+  }
+  if constexpr (KIND == TypeKind::REAL) {
+    return FloatingPointUtil<float>::NaNAwareLessThan{}(
+        a.value<KIND>(), b.value<KIND>());
+  }
+  if constexpr (KIND == TypeKind::DOUBLE) {
+    return FloatingPointUtil<double>::NaNAwareLessThan{}(
+        a.value<KIND>(), b.value<KIND>());
+  }
+  return a.value<KIND>() < b.value<KIND>();
+}
+
+// Compares two variants of kind KIND: a null operand never compares equal,
+// and REAL and DOUBLE values treat any two NaNs as equal.
+template <TypeKind KIND>
+bool variant::equals(const variant& a, const variant& b) const {
+  if (a.isNull() || b.isNull()) {
+    return false;
+  }
+  if constexpr (KIND == TypeKind::REAL) {
+    return FloatingPointUtil<float>::NaNAwareEquals{}(
+        a.value<KIND>(), b.value<KIND>());
+  }
+  if constexpr (KIND == TypeKind::DOUBLE) {
+    return FloatingPointUtil<double>::NaNAwareEquals{}(
+        a.value<KIND>(), b.value<KIND>());
+  }
+  // todo(youknowjack): centralize equality semantics
+  return a.value<KIND>() == b.value<KIND>();
+}
+
 uint64_t variant::hash() const {
   uint64_t hash = 0;
   if (isNull()) {
@@ -690,9 +731,10 @@ uint64_t variant::hash() const {
     case TypeKind::BOOLEAN:
       return folly::Hash{}(value<TypeKind::BOOLEAN>());
     case TypeKind::REAL:
-      return folly::Hash{}(value<TypeKind::REAL>());
+      return FloatingPointUtil<float>::NaNAwareHash{}(value<TypeKind::REAL>());
     case TypeKind::DOUBLE:
-      return folly::Hash{}(value<TypeKind::DOUBLE>());
+      return FloatingPointUtil<double>::NaNAwareHash{}(
+          value<TypeKind::DOUBLE>());
     case TypeKind::VARBINARY:
       return folly::Hash{}(value<TypeKind::VARBINARY>());
     case TypeKind::VARCHAR:
diff --git a/velox/type/Variant.h b/velox/type/Variant.h
index 2021d7069b4e..919751c74a45 100644
--- a/velox/type/Variant.h
+++ b/velox/type/Variant.h
@@ -130,24 +130,10 @@ class variant {
   variant(TypeKind kind, void* ptr) : kind_{kind}, ptr_{ptr} {}
 
   template <TypeKind KIND>
-  bool lessThan(const variant& a, const variant& b) const {
-    if (a.isNull() && !b.isNull()) {
-      return true;
-    }
-    if (a.isNull() || b.isNull()) {
-      return false;
-    }
-    return a.value<KIND>() < b.value<KIND>();
-  }
+  bool lessThan(const variant& a, const variant& b) const;
 
   template <TypeKind KIND>
-  bool equals(const variant& a, const variant& b) const {
-    if (a.isNull() || b.isNull()) {
-      return false;
-    }
-    // todo(youknowjack): centralize equality semantics
-    return a.value<KIND>() == b.value<KIND>();
-  }
+  bool equals(const variant& a, const variant& b) const;
 
   template <TypeKind KIND>
   void typedDestroy() {
diff --git a/velox/type/tests/VariantTest.cpp b/velox/type/tests/VariantTest.cpp
index f195b4d42355..9882c76cb4a7 100644
--- a/velox/type/tests/VariantTest.cpp
+++ b/velox/type/tests/VariantTest.cpp
@@ -197,6 +197,33 @@ TEST(VariantTest, equalsWithEpsilonFloat) {
   ASSERT_FALSE(variant(sum1).equalsWithEpsilon(variant(sum3)));
 }
 
+TEST(VariantTest, mapWithNaNKey) {
+  // Verify that map variants treat all NaN keys as equivalent and comparable
+  // (consider them the largest) with other values.
+  static const double kNan = std::numeric_limits<double>::quiet_NaN();
+  auto mapType = MAP(DOUBLE(), INTEGER());
+  {
+    // NaN added at the start of insertions.
+    std::map<variant, variant> mapVariant;
+    mapVariant.insert({variant(kNan), variant(1)});
+    mapVariant.insert({variant(1.2), variant(2)});
+    mapVariant.insert({variant(12.4), variant(3)});
+    EXPECT_EQ(
+        "[{\"key\":1.2,\"value\":2},{\"key\":12.4,\"value\":3},{\"key\":\"NaN\",\"value\":1}]",
+        variant::map(mapVariant).toJson(mapType));
+  }
+  {
+    // NaN added in the middle of insertions.
+    std::map<variant, variant> mapVariant;
+    mapVariant.insert({variant(1.2), variant(2)});
+    mapVariant.insert({variant(kNan), variant(1)});
+    mapVariant.insert({variant(12.4), variant(3)});
+    EXPECT_EQ(
+        "[{\"key\":1.2,\"value\":2},{\"key\":12.4,\"value\":3},{\"key\":\"NaN\",\"value\":1}]",
+        variant::map(mapVariant).toJson(mapType));
+  }
+}
+
 TEST(VariantTest, serialize) {
   // Null values.
   testSerDe(variant(TypeKind::BOOLEAN));