Skip to content


Fix NaN handling for map subscript and add test for map() (facebookin…
Browse files Browse the repository at this point in the history


Ensures that map subscript recognizes NaN as a key, treating NaNs with any
binary representation as equal.

Differential Revision: D57634535
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed May 21, 2024
1 parent 47c6b25 commit 42cca87
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 11 deletions.
4 changes: 3 additions & 1 deletion velox/docs/functions/presto/map.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ Map Functions
.. function:: map(array(K), array(V)) -> map(K,V)

Returns a map created using the given key/value arrays. Keys are not allowed to be null or to contain nulls. ::
Returns a map created using the given key/value arrays. Keys are not allowed to be null or to contain nulls.
For REAL and DOUBLE, NaNs (Not-a-Number) are considered equal. ::

SELECT map(ARRAY[1,3], ARRAY[2,4]); -- {1 -> 2, 3 -> 4}

Expand Down Expand Up @@ -147,6 +148,7 @@ Map Functions

Returns value for given ``key``. Return null if the key is not contained in the map.
For REAL and DOUBLE, NaNs (Not-a-Number) are considered equal and can be used as keys.
Corresponds to SQL subscript operator [].

SELECT name_to_age_map['Bob'] AS bob_age;
Expand Down
12 changes: 11 additions & 1 deletion velox/functions/lib/SubscriptUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ namespace facebook::velox::functions {

namespace {

// Returns true if 'lhs' and 'rhs' are equal. Floating point types are compared
// using util::floating_point::NaNAwareEquals rather than operator==; all other
// types are compared using operator==.
template <typename T>
inline bool isPrimitiveEqual(const T& lhs, const T& rhs) {
if constexpr (std::is_floating_point_v<T>) {
return util::floating_point::NaNAwareEquals<T>{}(lhs, rhs);
} else {
return lhs == rhs;

template <TypeKind Kind>
struct SimpleType {
using type = typename TypeTraits<Kind>::NativeType;
Expand Down Expand Up @@ -128,7 +137,8 @@ VectorPtr applyMapTyped(
} else {
// Search map without caching.
for (size_t offset = offsetStart; offset < offsetEnd; ++offset) {
if (decodedMapKeys->valueAt<TKey>(offset) == searchKey) {
if (isPrimitiveEqual<TKey>(
decodedMapKeys->valueAt<TKey>(offset), searchKey)) {
rawIndices[row] = offset;
found = true;
Expand Down
9 changes: 3 additions & 6 deletions velox/functions/lib/SubscriptUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/expression/VectorReaders.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/type/Type.h"
#include "velox/vector/BaseVector.h"
#include "velox/vector/ComplexVector.h"
Expand Down Expand Up @@ -77,12 +78,8 @@ class LookupTable : public LookupTableBase {
using inner_allocator_t =
memory::StlAllocator<std::pair<key_t const, vector_size_t>>;

using inner_map_t = folly::F14FastMap<
using inner_map_t = typename util::floating_point::
HashMapNaNAwareTypeTraits<key_t, vector_size_t, inner_allocator_t>::Type;

using outer_allocator_t =
memory::StlAllocator<std::pair<vector_size_t const, inner_map_t>>;
Expand Down
96 changes: 96 additions & 0 deletions velox/functions/prestosql/tests/ElementAtTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,93 @@ class ElementAtTest : public FunctionBaseTest {
"{10: 10, 11: 11, 12: 12}",

template <typename T>
void testFloatingPointCornerCases() {
static const T kNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSNaN = std::numeric_limits<T>::signaling_NaN();

auto values = makeFlatVector<int32_t>({1, 2, 3, 4, 5});
auto expected = makeConstant<int32_t>(3, 1);

auto elementAt = [&](auto map, auto search) {
return evaluate("element_at(C0, C1)", makeRowVector({map, search}));

// Case 1: Verify NaNs are identified even with different binary
// representations.
auto keysIdenticalNaNs = makeFlatVector<T>({1, 2, kNaN, 4, 5});
auto mapVector = makeMapVector({0}, keysIdenticalNaNs, values);
expected, elementAt(mapVector, makeConstant<T>(kNaN, 1)));
expected, elementAt(mapVector, makeConstant<T>(kSNaN, 1)));

// Case 2: Verify for equality of +0.0 and -0.0.
auto keysDifferentZeros = makeFlatVector<T>({1, 2, -0.0, 4, 5});
mapVector = makeMapVector({0}, keysDifferentZeros, values);
expected, elementAt(mapVector, makeConstant<T>(0.0, 1)));
expected, elementAt(mapVector, makeConstant<T>(-0.0, 1)));

// Case 3: Verify NaNs are identified when nested inside complex type keys.
auto rowKeys = makeRowVector(
{makeFlatVector<T>({1, 2, kNaN, 4, 5, 6}),
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6})});
auto mapOfRowKeys = makeMapVector(
{0, 3}, rowKeys, makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}));
auto elementValue = makeRowVector(
{makeFlatVector<T>({kSNaN, 1}), makeFlatVector<int32_t>({3, 1})});
auto element = BaseVector::wrapInConstant(2, 0, elementValue);
auto expected = makeNullableFlatVector<int32_t>({3, std::nullopt});
auto result = evaluate(
"element_at(C0, C1)", makeRowVector({mapOfRowKeys, element}));
test::assertEqualVectors(expected, result);
// Case 4: Verify NaNs are identified when employing caching.
exec::ExprSet exprSet({}, &execCtx_);
auto inputs = makeRowVector({});
exec::EvalCtx evalCtx(&execCtx_, &exprSet, inputs.get());

SelectivityVector rows(1);
auto inputMap = makeMapVector({0}, keysIdenticalNaNs, values);

auto keys = makeFlatVector<T>(std::vector<T>({kSNaN}));
std::vector<VectorPtr> args = {inputMap, keys};

facebook::velox::functions::MapSubscript mapSubscriptWithCaching(true);

auto checkStatus = [&](bool cachingEnabled,
bool materializedMapIsNull,
const VectorPtr& firtSeen) {
EXPECT_EQ(cachingEnabled, mapSubscriptWithCaching.cachingEnabled());
EXPECT_EQ(firtSeen, mapSubscriptWithCaching.firstSeenMap());
nullptr == mapSubscriptWithCaching.lookupTable());

// Initial state.
checkStatus(true, true, nullptr);

auto result1 = mapSubscriptWithCaching.applyMap(rows, args, evalCtx);
// Nothing has been materialized yet since the input is seen only once.
checkStatus(true, true, args[0]);

auto result2 = mapSubscriptWithCaching.applyMap(rows, args, evalCtx);
checkStatus(true, false, args[0]);

auto result3 = mapSubscriptWithCaching.applyMap(rows, args, evalCtx);
checkStatus(true, false, args[0]);

// All the results should be the same.
expected = makeConstant<int32_t>(3, 1);
test::assertEqualVectors(expected, result2);
test::assertEqualVectors(result1, result2);
test::assertEqualVectors(result2, result3);

template <>
Expand Down Expand Up @@ -1086,3 +1173,12 @@ TEST_F(ElementAtTest, testCachingOptimzation) {
test::assertEqualVectors(result, result1);

TEST_F(ElementAtTest, floatingPointCornerCases) {
// Verify that different code paths (keys of simple types, complex types and
// optimized caching) correctly identify NaNs and treat all NaNs with
// different binary representations as equal. Also verifies that -/+ 0.0 are
// considered equal.
65 changes: 64 additions & 1 deletion velox/functions/prestosql/tests/MapTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.

#include <functional>
#include <optional>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"
Expand All @@ -25,7 +26,64 @@ using namespace facebook::velox::functions::test;

namespace {

class MapTest : public FunctionBaseTest {};
class MapTest : public FunctionBaseTest {
template <typename T>
void testFloatingPointCornerCases() {
static const T kNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSNaN = std::numeric_limits<T>::signaling_NaN();
// Case 1: Check for duplicate NaNs with the same binary representation.
VectorPtr keysIdenticalNaNs =
makeNullableArrayVector<T>({{1, 2, kNaN, 4, 5, kNaN}});
// Case 2: Check for duplicate NaNs with different binary representation.
VectorPtr keysDifferentNaNs =
makeNullableArrayVector<T>({{1, 2, kNaN, 4, 5, kSNaN}});
// Case 3: Check for duplicate NaNs when the keys vector is a constant. This
// is to ensure the code path for constant keys is exercised.
VectorPtr keysConstant =
BaseVector::wrapInConstant(1, 0, keysDifferentNaNs);
// Case 4: Check for duplicate NaNs when the keys vector is wrapped in a
// dictionary.
VectorPtr keysInDictionary =
wrapInDictionary(makeIndices(1, std::identity{}), keysDifferentNaNs);
// Case 5: Check for equality of +0.0 and -0.0.
VectorPtr keysDifferentZeros =
makeNullableArrayVector<T>({{1, 2, -0.0, 4, 5, 0.0}});
auto values = makeNullableArrayVector<int32_t>({{1, 2, 3, 4, 5, 6}});

auto checkDuplicate = [&](VectorPtr& keys, std::string expectedError) {
evaluate("map(c0, c1)", makeRowVector({keys, values})),

evaluate("try(map(c0, c1))", makeRowVector({keys, values})));

// Try the map version that allows duplicates.
ASSERT_NO_THROW(evaluate("map2(c0, c1)", makeRowVector({keys, values})));

keysIdenticalNaNs, "Duplicate map keys (NaN) are not allowed");
keysDifferentNaNs, "Duplicate map keys (NaN) are not allowed");
checkDuplicate(keysConstant, "Duplicate map keys (NaN) are not allowed");
keysInDictionary, "Duplicate map keys (NaN) are not allowed");
keysDifferentZeros, "Duplicate map keys (0) are not allowed");

// Case 6: Check for duplicate NaNs nested inside a complex key.
VectorPtr arrayOfRows = makeArrayVector(
{makeFlatVector<T>({1, 2, kNaN, 4, 5, kSNaN}),
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 3})}));
arrayOfRows, "Duplicate map keys ({NaN, 3}) are not allowed");

TEST_F(MapTest, noNulls) {
auto size = 1'000;
Expand Down Expand Up @@ -170,6 +228,11 @@ TEST_F(MapTest, duplicateKeys) {
ASSERT_NO_THROW(evaluate("map2(c0, c1)", makeRowVector({keys, values})));

TEST_F(MapTest, floatingPointCornerCases) {

TEST_F(MapTest, fewerValuesThanKeys) {
auto size = 1'000;

Expand Down
55 changes: 53 additions & 2 deletions velox/type/FloatingPointUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cmath>
#include <vector>

#include <folly/container/F14Map.h>
#include <folly/container/F14Set.h>

namespace facebook::velox {
Expand Down Expand Up @@ -83,8 +84,9 @@ struct NaNAwareHash {

// Utility struct to provide a clean way of defining a hash set type using
// folly::F14FastSet with overrides for floating point types.
// Utility struct to provide a clean way of defining a hash set and map type
// using folly::F14FastSet and folly::F14FastMap respectively with overrides for
// floating point types.
template <typename Key>
class HashSetNaNAware : public folly::F14FastSet<Key> {};

Expand All @@ -97,6 +99,55 @@ template <>
class HashSetNaNAware<double>
: public folly::
F14FastSet<double, NaNAwareHash<double>, NaNAwareEquals<double>> {};

template <
typename Key,
typename Mapped,
typename Alloc = folly::f14::DefaultAlloc<std::pair<Key const, Mapped>>>
struct HashMapNaNAwareTypeTraits {
using Type = folly::F14FastMap<

template <typename Mapped, typename Alloc>
struct HashMapNaNAwareTypeTraits<float, Mapped, Alloc> {
using Type = folly::F14FastMap<

template <typename Mapped, typename Alloc>
struct HashMapNaNAwareTypeTraits<double, Mapped, Alloc> {
using Type = folly::F14FastMap<

/* template <typename Mapped, typename Alloc>
class HashMapNaNAware<float, Mapped, Alloc> : public folly::F14FastMap<
Alloc> {};
template <typename Mapped, typename Alloc>
class HashMapNaNAware<double, Mapped, Alloc> : public folly::F14FastMap<
Alloc> {}; */
} // namespace util::floating_point

/// A static class that holds helper functions for DOUBLE type.
Expand Down

0 comments on commit 42cca87

Please sign in to comment.