Skip to content

Commit

Permalink
Fix NaN handling for array_remove and contains UDFs
Browse files Browse the repository at this point in the history
Summary: Ensures that NaNs are considered as being equal to each other.

Differential Revision: D57342173
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed May 14, 2024
1 parent e609fa5 commit 1056132
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 6 deletions.
6 changes: 5 additions & 1 deletion velox/docs/functions/presto/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,11 @@ Array Functions
.. function:: array_remove(x, element) -> array

Remove all elements that equal ``element`` from array ``x``.
For REAL and DOUBLE, NANs (Not-a-Number) are considered equal.

SELECT array_remove(ARRAY [1, 2, 3], 3); -- [1, 2]
SELECT array_remove(ARRAY [2, 1, NULL], 1); -- [2, NULL]
SELECT array_remove(ARRAY [2.1, 1.1, nan()], nan()); -- [2.1, 1.1]

.. function:: array_sort(array(E)) -> array(E)

Expand Down Expand Up @@ -231,8 +233,10 @@ Array Functions

Returns true if the array ``x`` contains the ``element``.
When 'element' is of complex type, throws if 'x' or 'element' contains nested nulls
and these need to be compared to produce a result. ::
and these need to be compared to produce a result.
For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. ::

SELECT contains(ARRAY [2.1, 1.1, nan()], nan()); -- true.
SELECT contains(ARRAY[ARRAY[1, 3]], ARRAY[2, null]); -- false.
SELECT contains(ARRAY[ARRAY[2, 3]], ARRAY[2, null]); -- failed: contains does not support arrays with elements that are null or contain null
SELECT contains(ARRAY[ARRAY[2, null]], ARRAY[2, 1]); -- failed: contains does not support arrays with elements that are null or contain null
Expand Down
15 changes: 12 additions & 3 deletions velox/functions/prestosql/ArrayContains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "velox/expression/VectorFunction.h"
#include "velox/type/FloatingPointUtil.h"
#include "velox/vector/DecodedVector.h"

namespace facebook::velox::functions {
Expand Down Expand Up @@ -47,9 +48,17 @@ void applyTyped(
auto offset = rawOffsets[indices[row]];

for (auto i = 0; i < size; i++) {
if (rawElements[offset + i] == search) {
flatResult.set(row, true);
return;
if constexpr (std::is_floating_point_v<T>) {
if (util::floating_point::NaNAwareEquals<T>{}(
rawElements[offset + i], search)) {
flatResult.set(row, true);
return;
}
} else {
if (rawElements[offset + i] == search) {
flatResult.set(row, true);
return;
}
}
}

Expand Down
11 changes: 9 additions & 2 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -871,8 +871,15 @@ struct ArrayRemoveFunction {
void call(Out& out, const In& inputArray, E element) {
for (const auto& item : inputArray) {
if (item.has_value()) {
if (element != item.value()) {
out.push_back(item.value());
if constexpr (std::is_floating_point_v<E>) {
if (!util::floating_point::NaNAwareEquals<E>{}(
element, item.value())) {
out.push_back(item.value());
}
} else {
if (element != item.value()) {
out.push_back(item.value());
}
}
} else {
out.add_null();
Expand Down
15 changes: 15 additions & 0 deletions velox/functions/prestosql/tests/ArrayContainsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ class ArrayContainsTest : public FunctionBaseTest {

assertEqualVectors(makeNullableFlatVector<bool>(expected), result);
};

template <typename T>
void testFloatingPointNaNs() {
static const T kQuietNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSignalingNaN = std::numeric_limits<T>::signaling_NaN();
auto arrayVector = makeArrayVector<T>(
{{1, 2, 3, 4}, {3, 4, kQuietNaN}, {5, 6, 7, 8, kSignalingNaN}});

testContains(arrayVector, kQuietNaN, {false, true, true});
}
};

TEST_F(ArrayContainsTest, integerNoNulls) {
Expand Down Expand Up @@ -423,4 +433,9 @@ TEST_F(ArrayContainsTest, rowCheckNulls) {
// (3, null) = (3, null) is true in $internal$contains.
ASSERT_TRUE(contains({3, std::nullopt}, true));
}

TEST_F(ArrayContainsTest, floatNaNs) {
testFloatingPointNaNs<float>();
testFloatingPointNaNs<double>();
}
} // namespace
24 changes: 24 additions & 0 deletions velox/functions/prestosql/tests/ArrayRemoveTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@ class ArrayRemoveTest : public FunctionBaseTest {
VELOX_ASSERT_THROW(
evaluate(expression, makeRowVector(input)), expectedError);
}

template <typename T>
void testFloats() {
static const T kQuietNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSignalingNaN = std::numeric_limits<T>::signaling_NaN();
const auto arrayVector = makeNullableArrayVector<T>(
{{1, std::nullopt, 2, 3, std::nullopt, 4},
{3, 4, 5, kQuietNaN, 3, 4, kQuietNaN},
{kSignalingNaN, 8, 9}});
const auto elementVector = makeFlatVector<T>({3, kQuietNaN, kQuietNaN});
const auto expected = makeNullableArrayVector<T>({
{1, std::nullopt, 2, std::nullopt, 4},
{3, 4, 5, 3, 4},
{8, 9},
});
testExpression(
"array_remove(c0, c1)", {arrayVector, elementVector}, expected);
}
};

//// Remove simple-type elements from array.
Expand All @@ -60,6 +78,12 @@ TEST_F(ArrayRemoveTest, arrayWithSimpleTypes) {
"array_remove(c0, c1)", {arrayVector, elementVector}, expected);
}

//// Remove simple-type elements from array.
TEST_F(ArrayRemoveTest, arrayWithFloatTypes) {
testFloats<float>();
testFloats<double>();
}

//// Remove simple-type elements from array.
TEST_F(ArrayRemoveTest, arrayWithString) {
const auto arrayVector = makeNullableArrayVector<std::string>(
Expand Down

0 comments on commit 1056132

Please sign in to comment.