Skip to content

Commit

Permalink
Fix NaN handling for array_position (facebookincubator#9832)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#9832

Ensures that array_position treats all NaNs as equal to each other,
so that NaN elements in the input array can be located.

Differential Revision: D57416206
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed May 16, 2024
1 parent f81cfc9 commit 90dfea2
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 4 deletions.
2 changes: 2 additions & 0 deletions velox/docs/functions/presto/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,13 @@ Array Functions
.. function:: array_position(x, element) -> bigint

Returns the position of the first occurrence of the ``element`` in array ``x`` (or 0 if not found).
For REAL and DOUBLE, NaNs (Not-a-Number) are considered equal to each other.

.. function:: array_position(x, element, instance) -> bigint
:noindex:

If ``instance > 0``, returns the position of the ``instance``-th occurrence of the ``element`` in array ``x``. If ``instance < 0``, returns the position of the ``instance``-to-last occurrence of the ``element`` in array ``x``. If no matching element instance is found, 0 is returned.
For REAL and DOUBLE, NaNs (Not-a-Number) are considered equal to each other.

.. function:: array_remove(x, element) -> array

Expand Down
18 changes: 14 additions & 4 deletions velox/functions/prestosql/ArrayPosition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,21 @@

#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/vector/DecodedVector.h"

namespace facebook::velox::functions {
namespace {

// Equality predicate used when matching primitive array elements against the
// search value.
//
// Non-floating-point types are compared with operator==. Floating point types
// are delegated to util::floating_point::NaNAwareEquals so that a NaN search
// value can match a NaN element (plain operator== never reports NaN == NaN).
template <typename T>
inline bool isPrimitiveEqual(const T& lhs, const T& rhs) {
  if constexpr (!std::is_floating_point_v<T>) {
    return lhs == rhs;
  } else {
    return util::floating_point::NaNAwareEquals<T>{}(lhs, rhs);
  }
}

// Find the index of the first match for primitive types.
template <
TypeKind kind,
Expand Down Expand Up @@ -76,7 +86,7 @@ void applyTypedFirstMatch(

int i;
for (i = 0; i < size; i++) {
if (rawElements[offset + i] == search) {
if (isPrimitiveEqual<T>(rawElements[offset + i], search)) {
flatResult.set(row, i + 1);
break;
}
Expand All @@ -99,7 +109,7 @@ void applyTypedFirstMatch(
int i;
for (i = 0; i < size; i++) {
if (!elementsDecoded.isNullAt(offset + i) &&
elementsDecoded.valueAt<T>(offset + i) == search) {
isPrimitiveEqual<T>(elementsDecoded.valueAt<T>(offset + i), search)) {
flatResult.set(row, i + 1);
break;
}
Expand Down Expand Up @@ -246,7 +256,7 @@ void applyTypedWithInstance(

int i;
for (i = startIndex; i != endIndex; i += step) {
if (rawElements[offset + i] == search) {
if (isPrimitiveEqual<T>(rawElements[offset + i], search)) {
if (--remaining == 0) {
flatResult.set(row, i + 1);
break;
Expand Down Expand Up @@ -278,7 +288,7 @@ void applyTypedWithInstance(
int i;
for (i = startIndex; i != endIndex; i += step) {
if (!elementsDecoded.isNullAt(offset + i) &&
elementsDecoded.valueAt<T>(offset + i) == search) {
isPrimitiveEqual<T>(elementsDecoded.valueAt<T>(offset + i), search)) {
--instance;
if (instance == 0) {
flatResult.set(row, i + 1);
Expand Down
95 changes: 95 additions & 0 deletions velox/functions/prestosql/tests/ArrayPositionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,17 @@ class ArrayPositionTest : public FunctionBaseTest {
makeNullableFlatVector<int64_t>(expected));
}

// Runs "array_position(c0, c1)" where c0 is 'arrayVector' and c1 is a constant
// built from the single (possibly null) 'search' value, repeated once per row
// of 'arrayVector'. 'expected' holds the expected position for each row and is
// handed to evalExpr as the expected result vector.
template <typename T>
void testPosition(
    const ArrayVectorPtr& arrayVector,
    const std::optional<T>& search,
    const std::vector<std::optional<int64_t>>& expected) {
  const auto numRows = arrayVector->size();
  auto searchVector = makeConstant(search, numRows);
  auto expectedVector = makeNullableFlatVector<int64_t>(expected);

  evalExpr(
      {arrayVector, searchVector}, "array_position(c0, c1)", expectedVector);
}

void testPosition(
const ArrayVectorPtr& arrayVector,
const std::vector<int64_t>& search,
Expand Down Expand Up @@ -151,6 +162,20 @@ class ArrayPositionTest : public FunctionBaseTest {
makeNullableFlatVector<int64_t>(expected));
}

// Runs "array_position(c0, c1, c2)" where c0 is 'array', c1 is a constant
// built from the single (possibly null) 'search' value and c2 is a constant
// built from 'instance'; both constants are sized to the number of rows in
// 'array'. 'expected' holds the expected position for each row and is handed
// to evalExpr as the expected result vector.
template <typename T>
void testPositionWithInstance(
    const ArrayVectorPtr& array,
    const std::optional<T>& search,
    const int64_t instance,
    const std::vector<std::optional<int64_t>>& expected) {
  const auto numRows = array->size();
  auto searchVector = makeConstant(search, numRows);
  auto instanceVector = makeConstant(instance, numRows);
  auto expectedVector = makeNullableFlatVector<int64_t>(expected);

  evalExpr(
      {array, searchVector, instanceVector},
      "array_position(c0, c1, c2)",
      expectedVector);
}

template <typename T>
void testPositionWithInstanceNoNulls(
const std::vector<std::vector<T>>& array,
Expand Down Expand Up @@ -239,6 +264,71 @@ class ArrayPositionTest : public FunctionBaseTest {
"array_position(c0, c1, c2)",
dictExpectedVector);
}

// Verify that all NaNs are treated as equal and are identifiable in the input
// array.
//
// Covers two shapes of input for floating point type T:
//  - arrays of T, with and without null elements, searched through both
//    array_position(x, element) and array_position(x, element, instance);
//  - arrays of ROW(DOUBLE, VARCHAR) whose first field may be NaN.
// Both a quiet and a signaling NaN are used as the search value.
template <typename T>
void testFloatingPointNaN() {
  static const T kNaN = std::numeric_limits<T>::quiet_NaN();
  static const T kSNaN = std::numeric_limits<T>::signaling_NaN();

  // Test NaN in a simple array.
  ArrayVectorPtr arrayVectorWithNulls = makeNullableArrayVector<T>({
      {1, std::nullopt, kNaN, 4, kNaN, 5, 6, kNaN, 7},
      {1, std::nullopt, kSNaN, 4, kNaN, 5, 6, kNaN, 7},
  });
  // This exercises the optimization for null free input.
  ArrayVectorPtr arrayVectorWithoutNulls = makeArrayVector<T>({
      {1, 3, kNaN, 4, kNaN, 5, 6, kNaN, 7},
      {1, 3, kSNaN, 4, kNaN, 5, 6, kNaN, 7},
  });

  // NaNs sit at positions 3, 5 and 8 (1-based) in every row, so the first,
  // second and last occurrences are expected at 3, 5 and 8 respectively.
  for (auto& arrayVectorPtr :
       {arrayVectorWithNulls, arrayVectorWithoutNulls}) {
    for (const T& nan : {kNaN, kSNaN}) {
      testPosition<T>(arrayVectorPtr, nan, {3, 3});
      testPositionWithInstance<T>(arrayVectorPtr, nan, 1, {3, 3});
      testPositionWithInstance<T>(arrayVectorPtr, nan, 2, {5, 5});
      testPositionWithInstance<T>(arrayVectorPtr, nan, -1, {8, 8});
    }
  }

  // Test NaN within an array of complex type.
  std::vector<std::vector<std::optional<std::tuple<double, std::string>>>>
      data = {
          {{{1, "red"}}, {{kNaN, "blue"}}, {{3, "green"}}},
          {{{kNaN, "blue"}}, std::nullopt, {{5, "green"}}},
          {},
          {std::nullopt},
          {{{1, "yellow"}},
           {{kNaN, "blue"}},
           {{4, "green"}},
           {{5, "purple"}}},
      };

  auto rowType = ROW({DOUBLE(), VARCHAR()});
  auto arrayVector = makeArrayOfRowVector(data, rowType);
  auto size = arrayVector->size();

  // Searches 'arrayVector' for the constant row (n, color) and checks the
  // per-row positions against 'expected'.
  // NOTE(review): 'n' is always a double, so for T = float the NaN values are
  // converted float -> double before the search; a signaling NaN may be
  // quieted by that conversion - confirm the sNaN case is still exercised.
  auto testPositionOfRow =
      [&](double n,
          const char* color,
          const std::vector<std::optional<int64_t>>& expected) {
        auto expectedVector = makeNullableFlatVector<int64_t>(expected);
        auto searchVector =
            makeConstantRow(rowType, variant::row({n, color}), size);

        // Reuse the vectors built above instead of constructing identical
        // copies inline.
        evalExpr(
            {arrayVector, searchVector},
            "array_position(c0, c1)",
            expectedVector);
      };

  testPositionOfRow(1, "red", {1, 0, 0, 0, 0});
  testPositionOfRow(kNaN, "blue", {2, 1, 0, 0, 2});
  testPositionOfRow(kSNaN, "blue", {2, 1, 0, 0, 2});
}
};

TEST_F(ArrayPositionTest, integer) {
Expand Down Expand Up @@ -882,4 +972,9 @@ TEST_F(ArrayPositionTest, dictionaryEncodingElements) {
testPosition(arrayVector, {1, 2, 3, 4}, {5, 2}, -1);
}

// Runs the NaN-equality checks of testFloatingPointNaN for both floating point
// element types, float and double.
TEST_F(ArrayPositionTest, floatNaN) {
testFloatingPointNaN<float>();
testFloatingPointNaN<double>();
}

} // namespace

0 comments on commit 90dfea2

Please sign in to comment.