diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index 86e7e2595514..cc425dbad404 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -149,9 +149,11 @@ Array Functions .. function:: array_remove(x, element) -> array Remove all elements that equal ``element`` from array ``x``. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. SELECT array_remove(ARRAY [1, 2, 3], 3); -- [1, 2] SELECT array_remove(ARRAY [2, 1, NULL], 1); -- [2, NULL] + SELECT array_remove(ARRAY [2.1, 1.1, nan()], nan()); -- [2.1, 1.1] .. function:: array_sort(array(E)) -> array(E) @@ -231,8 +233,10 @@ Array Functions Returns true if the array ``x`` contains the ``element``. When 'element' is of complex type, throws if 'x' or 'element' contains nested nulls - and these need to be compared to produce a result. :: + and these need to be compared to produce a result. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. :: + SELECT contains(ARRAY [2.1, 1.1, nan()], nan()); -- true. SELECT contains(ARRAY[ARRAY[1, 3]], ARRAY[2, null]); -- false. SELECT contains(ARRAY[ARRAY[2, 3]], ARRAY[2, null]); -- failed: contains does not support arrays with elements that are null or contain null SELECT contains(ARRAY[ARRAY[2, null]], ARRAY[2, 1]); -- failed: contains does not support arrays with elements that are null or contain null diff --git a/velox/functions/prestosql/ArrayContains.cpp b/velox/functions/prestosql/ArrayContains.cpp index bc5f8e912919..8e13cbbe94a1 100644 --- a/velox/functions/prestosql/ArrayContains.cpp +++ b/velox/functions/prestosql/ArrayContains.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "velox/expression/VectorFunction.h" +#include "velox/type/FloatingPointUtil.h" #include "velox/vector/DecodedVector.h" namespace facebook::velox::functions { @@ -47,9 +48,17 @@ void applyTyped( auto offset = rawOffsets[indices[row]]; for (auto i = 0; i < size; i++) { - if (rawElements[offset + i] == search) { - flatResult.set(row, true); - return; + if constexpr (std::is_floating_point_v) { + if (util::floating_point::NaNAwareEquals{}( + rawElements[offset + i], search)) { + flatResult.set(row, true); + return; + } + } else { + if (rawElements[offset + i] == search) { + flatResult.set(row, true); + return; + } } } diff --git a/velox/functions/prestosql/ArrayFunctions.h b/velox/functions/prestosql/ArrayFunctions.h index 363560d75fd8..e53c25411819 100644 --- a/velox/functions/prestosql/ArrayFunctions.h +++ b/velox/functions/prestosql/ArrayFunctions.h @@ -871,8 +871,15 @@ struct ArrayRemoveFunction { void call(Out& out, const In& inputArray, E element) { for (const auto& item : inputArray) { if (item.has_value()) { - if (element != item.value()) { - out.push_back(item.value()); + if constexpr (std::is_floating_point_v) { + if (!util::floating_point::NaNAwareEquals{}( + element, item.value())) { + out.push_back(item.value()); + } + } else { + if (element != item.value()) { + out.push_back(item.value()); + } } } else { out.add_null(); diff --git a/velox/functions/prestosql/tests/ArrayContainsTest.cpp b/velox/functions/prestosql/tests/ArrayContainsTest.cpp index 0901be99f537..fd1acc220d17 100644 --- a/velox/functions/prestosql/tests/ArrayContainsTest.cpp +++ b/velox/functions/prestosql/tests/ArrayContainsTest.cpp @@ -53,6 +53,16 @@ class ArrayContainsTest : public FunctionBaseTest { assertEqualVectors(makeNullableFlatVector(expected), result); }; + + template + void testFloatingPointNaNs() { + static const T kQuietNaN = std::numeric_limits::quiet_NaN(); + static const T kSignalingNaN = std::numeric_limits::signaling_NaN(); + auto arrayVector = makeArrayVector( + {{1, 2, 3, 4}, {3, 4, kQuietNaN}, {5, 6, 7, 8, kSignalingNaN}}); + + testContains(arrayVector, kQuietNaN, {false, true, true}); + } }; TEST_F(ArrayContainsTest, integerNoNulls) { @@ -423,4 +433,9 @@ TEST_F(ArrayContainsTest, rowCheckNulls) { // (3, null) = (3, null) is true in $internal$contains. ASSERT_TRUE(contains({3, std::nullopt}, true)); } + +TEST_F(ArrayContainsTest, floatNaNs) { + testFloatingPointNaNs(); + testFloatingPointNaNs(); +} } // namespace diff --git a/velox/functions/prestosql/tests/ArrayRemoveTest.cpp b/velox/functions/prestosql/tests/ArrayRemoveTest.cpp index 4d11bcc90648..8a483087febb 100644 --- a/velox/functions/prestosql/tests/ArrayRemoveTest.cpp +++ b/velox/functions/prestosql/tests/ArrayRemoveTest.cpp @@ -40,6 +40,24 @@ class ArrayRemoveTest : public FunctionBaseTest { VELOX_ASSERT_THROW( evaluate(expression, makeRowVector(input)), expectedError); } + + template + void testFloats() { + static const T kQuietNaN = std::numeric_limits::quiet_NaN(); + static const T kSignalingNaN = std::numeric_limits::signaling_NaN(); + const auto arrayVector = makeNullableArrayVector( + {{1, std::nullopt, 2, 3, std::nullopt, 4}, + {3, 4, 5, kQuietNaN, 3, 4, kQuietNaN}, + {kSignalingNaN, 8, 9}}); + const auto elementVector = makeFlatVector({3, kQuietNaN, kQuietNaN}); + const auto expected = makeNullableArrayVector({ + {1, std::nullopt, 2, std::nullopt, 4}, + {3, 4, 5, 3, 4}, + {8, 9}, + }); + testExpression( + "array_remove(c0, c1)", {arrayVector, elementVector}, expected); + } }; //// Remove simple-type elements from array. @@ -60,6 +78,12 @@ TEST_F(ArrayRemoveTest, arrayWithSimpleTypes) { "array_remove(c0, c1)", {arrayVector, elementVector}, expected); } +//// Remove simple-type elements from array. +TEST_F(ArrayRemoveTest, arrayWithFloatTypes) { + testFloats(); + testFloats(); +} + //// Remove simple-type elements from array. TEST_F(ArrayRemoveTest, arrayWithString) { const auto arrayVector = makeNullableArrayVector(