diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index 05dea9101c176..86e7e2595514a 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -35,7 +35,8 @@ Array Functions .. function:: array_distinct(array(E)) -> array(E) - Remove duplicate values from the input array. :: + Remove duplicate values from the input array. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. :: SELECT array_distinct(ARRAY [1, 2, 3]); -- [1, 2, 3] SELECT array_distinct(ARRAY [1, 2, 1]); -- [1, 2] @@ -50,7 +51,8 @@ Array Functions .. function:: array_except(array(E) x, array(E) y) -> array(E) - Returns an array of the elements in array ``x`` but not in array ``y``, without duplicates. :: + Returns an array of the elements in array ``x`` but not in array ``y``, without duplicates. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. :: SELECT array_except(ARRAY [1, 2, 3], ARRAY [4, 5, 6]); -- [1, 2, 3] SELECT array_except(ARRAY [1, 2, 3], ARRAY [1, 2]); -- [3] @@ -77,7 +79,8 @@ Array Functions .. function:: array_intersect(array(E) x, array(E) y) -> array(E) - Returns an array of the elements in the intersection of array ``x`` and array ``y``, without duplicates. :: + Returns an array of the elements in the intersection of array ``x`` and array ``y``, without duplicates. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. :: SELECT array_intersect(ARRAY [1, 2, 3], ARRAY[4, 5, 6]); -- [] SELECT array_intersect(ARRAY [1, 2, 2], ARRAY[1, 1, 2]); -- [1, 2] @@ -127,6 +130,12 @@ Array Functions Tests if arrays ``x`` and ``y`` have any non-null elements in common. Returns null if there are no non-null elements in common but either array contains null. + For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. + +.. function:: array_union(x, y) -> array + + Returns an array of the elements in the union of x and y, without duplicates. 
+ For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. .. function:: array_position(x, element) -> bigint @@ -153,6 +162,7 @@ Array Functions SELECT array_sort(ARRAY [1, 2, 3]); -- [1, 2, 3] SELECT array_sort(ARRAY [3, 2, 1]); -- [1, 2, 3] + SELECT array_sort(ARRAY [infinity(), -1.1, nan(), 1.1, -Infinity(), 0]); -- [-Infinity, -1.1, 0, 1.1, Infinity, NaN] SELECT array_sort(ARRAY [2, 1, NULL]; -- [1, 2, NULL] SELECT array_sort(ARRAY [NULL, 1, NULL]); -- [1, NULL, NULL] SELECT array_sort(ARRAY [NULL, 2, 1]); -- [1, 2, NULL] diff --git a/velox/functions/prestosql/ArrayDistinct.cpp b/velox/functions/prestosql/ArrayDistinct.cpp index 86ff2867824d4..ba610f1fad4e9 100644 --- a/velox/functions/prestosql/ArrayDistinct.cpp +++ b/velox/functions/prestosql/ArrayDistinct.cpp @@ -14,12 +14,12 @@ * limitations under the License. */ -#include #include "velox/expression/EvalCtx.h" #include "velox/expression/Expr.h" #include "velox/expression/VectorFunction.h" #include "velox/functions/lib/RowsTranslationUtil.h" +#include "velox/type/FloatingPointUtil.h" namespace facebook::velox::functions { namespace { @@ -96,7 +96,9 @@ class ArrayDistinctFunction : public exec::VectorFunction { auto* rawNewOffsets = newOffsets->asMutable(); // Process the rows: store unique values in the hash table. 
- folly::F14FastSet uniqueSet; + using HashSetType = + typename util::floating_point::HashSetTypeTraits::Type; + HashSetType uniqueSet; rows.applyToSelected([&](vector_size_t row) { auto size = arrayVector->sizeAt(row); diff --git a/velox/functions/prestosql/ArrayFunctions.h b/velox/functions/prestosql/ArrayFunctions.h index 9bb3321a85303..363560d75fd81 100644 --- a/velox/functions/prestosql/ArrayFunctions.h +++ b/velox/functions/prestosql/ArrayFunctions.h @@ -34,14 +34,14 @@ struct ArrayMinMaxFunction { template void update(T& currentValue, const T& candidateValue) { if constexpr (std::is_same_v || std::is_same_v) { - using facebook::velox::util::floating_point::NaNAwareGreaterThan; - using facebook::velox::util::floating_point::NaNAwareLessThan; if constexpr (isMax) { - if (NaNAwareGreaterThan{}(candidateValue, currentValue)) { + if (util::floating_point::NaNAwareGreaterThan{}( + candidateValue, currentValue)) { currentValue = candidateValue; } } else { - if (NaNAwareLessThan{}(candidateValue, currentValue)) { + if (util::floating_point::NaNAwareLessThan{}( + candidateValue, currentValue)) { currentValue = candidateValue; } } @@ -836,7 +836,9 @@ struct ArrayUnionFunction { template void call(Out& out, const In& inputArray1, const In& inputArray2) { - folly::F14FastSet elementSet; + using HashSetType = typename util::floating_point::HashSetTypeTraits< + typename In::element_t>::Type; + HashSetType elementSet; bool nullAdded = false; auto addItems = [&](auto& inputArray) { for (const auto& item : inputArray) { diff --git a/velox/functions/prestosql/ArrayIntersectExcept.cpp b/velox/functions/prestosql/ArrayIntersectExcept.cpp index 1336ff2a0fe3b..143991a4f97ab 100644 --- a/velox/functions/prestosql/ArrayIntersectExcept.cpp +++ b/velox/functions/prestosql/ArrayIntersectExcept.cpp @@ -16,6 +16,7 @@ #include "velox/expression/VectorFunction.h" #include "velox/functions/lib/LambdaFunctionUtil.h" #include "velox/functions/lib/RowsTranslationUtil.h" +#include 
"velox/type/FloatingPointUtil.h" namespace facebook::velox::functions { namespace { @@ -31,7 +32,8 @@ struct SetWithNull { hasNull = false; } - folly::F14FastSet set; + using HashSetType = typename util::floating_point::HashSetTypeTraits::Type; + HashSetType set; bool hasNull{false}; static constexpr vector_size_t kInitialSetSize{128}; }; diff --git a/velox/functions/prestosql/ArraySort.cpp b/velox/functions/prestosql/ArraySort.cpp index aa8a7f4cd43e5..70d3ac152a105 100644 --- a/velox/functions/prestosql/ArraySort.cpp +++ b/velox/functions/prestosql/ArraySort.cpp @@ -22,6 +22,7 @@ #include "velox/functions/lib/LambdaFunctionUtil.h" #include "velox/functions/lib/RowsTranslationUtil.h" #include "velox/functions/prestosql/SimpleComparisonMatcher.h" +#include "velox/type/FloatingPointUtil.h" namespace facebook::velox::functions { namespace { @@ -172,6 +173,19 @@ void applyScalarType( bits::fillBits(rawBits, startRow, startRow + numOneBits, true); bits::fillBits(rawBits, endZeroRow, endRow, false); } + } else if constexpr (kind == TypeKind::REAL || kind == TypeKind::DOUBLE) { + T* resultRawValues = flatResults->mutableRawValues(); + if (ascending) { + std::sort( + resultRawValues + startRow, + resultRawValues + endRow, + util::floating_point::NaNAwareLessThan()); + } else { + std::sort( + resultRawValues + startRow, + resultRawValues + endRow, + util::floating_point::NaNAwareGreaterThan()); + } } else { T* resultRawValues = flatResults->mutableRawValues(); if (ascending) { diff --git a/velox/functions/prestosql/tests/ArrayDistinctTest.cpp b/velox/functions/prestosql/tests/ArrayDistinctTest.cpp index 9cebeac3b85cb..516185aa97b52 100644 --- a/velox/functions/prestosql/tests/ArrayDistinctTest.cpp +++ b/velox/functions/prestosql/tests/ArrayDistinctTest.cpp @@ -105,6 +105,8 @@ class ArrayDistinctTest : public FunctionBaseTest { {0.0, -10.0}, {std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::quiet_NaN(), + 
std::numeric_limits::signaling_NaN()}, {std::numeric_limits::signaling_NaN(), std::numeric_limits::signaling_NaN()}, {std::numeric_limits::lowest(), std::numeric_limits::lowest()}, @@ -134,10 +136,10 @@ class ArrayDistinctTest : public FunctionBaseTest { {0.0}, {0.0, 10.0}, {0.0, -10.0}, - {std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN()}, - {std::numeric_limits::signaling_NaN(), - std::numeric_limits::signaling_NaN()}, + {std::numeric_limits::quiet_NaN()}, + // quiet NaN and signaling NaN are treated equal + {std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::signaling_NaN()}, {std::numeric_limits::lowest()}, {std::nullopt}, {1.0001, -2.0, 3.03, std::nullopt, 4.00004}, diff --git a/velox/functions/prestosql/tests/ArrayExceptTest.cpp b/velox/functions/prestosql/tests/ArrayExceptTest.cpp index 7004caeb341ac..aef7de5d2e161 100644 --- a/velox/functions/prestosql/tests/ArrayExceptTest.cpp +++ b/velox/functions/prestosql/tests/ArrayExceptTest.cpp @@ -107,6 +107,8 @@ class ArrayExceptTest : public FunctionBaseTest { std::numeric_limits::infinity(), std::numeric_limits::max()}, {std::numeric_limits::quiet_NaN(), 9.0009}, + {std::numeric_limits::quiet_NaN(), 9.0009}, + {std::numeric_limits::quiet_NaN(), 9.0009}, }); auto array2 = makeNullableArrayVector({ {1.0, -2.0, 4.0}, @@ -114,24 +116,29 @@ class ArrayExceptTest : public FunctionBaseTest { {1.0001, -2.02, std::numeric_limits::max(), 8.00099}, {9.0009, std::numeric_limits::infinity()}, {9.0009, std::numeric_limits::quiet_NaN()}, - }); - - auto expected = makeNullableArrayVector({ - {1.0001, 3.03, std::nullopt, 4.00004}, - {2.02, 1}, - {8.0001, std::nullopt}, - {std::numeric_limits::max()}, {std::numeric_limits::quiet_NaN()}, + // quiet NaN and signaling NaN are treated equal + {std::numeric_limits::signaling_NaN()}, }); + + auto expected = makeNullableArrayVector( + {{1.0001, 3.03, std::nullopt, 4.00004}, + {2.02, 1}, + {8.0001, std::nullopt}, + {std::numeric_limits::max()}, + {}, + 
{9.0009}, + {9.0009}}); testExpr(expected, "array_except(C0, C1)", {array1, array2}); - expected = makeNullableArrayVector({ - {1.0, 4.0}, - {2.0199, 1.000001}, - {1.0001, -2.02, 8.00099}, - {}, - {std::numeric_limits::quiet_NaN()}, - }); + expected = makeNullableArrayVector( + {{1.0, 4.0}, + {2.0199, 1.000001}, + {1.0001, -2.02, 8.00099}, + {}, + {}, + {}, + {}}); testExpr(expected, "array_except(C1, C0)", {array1, array2}); } }; diff --git a/velox/functions/prestosql/tests/ArrayIntersectTest.cpp b/velox/functions/prestosql/tests/ArrayIntersectTest.cpp index 9e466a6ccfbfd..860a27793eb7e 100644 --- a/velox/functions/prestosql/tests/ArrayIntersectTest.cpp +++ b/velox/functions/prestosql/tests/ArrayIntersectTest.cpp @@ -106,20 +106,24 @@ class ArrayIntersectTest : public FunctionBaseTest { std::numeric_limits::infinity(), std::numeric_limits::max()}, {std::numeric_limits::quiet_NaN(), 9.0009}, + {std::numeric_limits::quiet_NaN(), 9.0009}, }); auto array2 = makeNullableArrayVector({ {1.0, -2.0, 4.0}, {std::numeric_limits::min(), 2.0199, -2.001, 1.000001}, {1.0001, -2.02, std::numeric_limits::max(), 8.00099}, {9.0009, std::numeric_limits::infinity()}, - {9.0009, std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::quiet_NaN()}, + // quiet NaN and signaling NaN are treated equal + {std::numeric_limits::signaling_NaN(), 9.0009}, }); auto expected = makeNullableArrayVector({ {-2.0}, {std::numeric_limits::min(), -2.001}, {std::numeric_limits::max()}, {9.0009, std::numeric_limits::infinity()}, - {9.0009}, + {std::numeric_limits::quiet_NaN()}, + {std::numeric_limits::quiet_NaN(), 9.0009}, }); testExpr(expected, "array_intersect(C0, C1)", {array1, array2}); diff --git a/velox/functions/prestosql/tests/ArraySortTest.cpp b/velox/functions/prestosql/tests/ArraySortTest.cpp index e164a7b19aee1..52bd8162769b7 100644 --- a/velox/functions/prestosql/tests/ArraySortTest.cpp +++ b/velox/functions/prestosql/tests/ArraySortTest.cpp @@ -224,6 +224,29 @@ class ArraySortTest : public 
FunctionBaseTest, } } + template + void testFloatingPoint() { + // Verify that NaNs are treated as greater than infinity + static const T kNaN = std::numeric_limits::quiet_NaN(); + static const T kInfinity = std::numeric_limits::infinity(); + static const T kNegativeInfinity = -1 * std::numeric_limits::infinity(); + + auto input = makeRowVector({makeNullableArrayVector( + {{kInfinity, -1, kNaN, 1, kNegativeInfinity, kNaN, 0}})}); + + { + auto expected = makeNullableArrayVector( + {{kNegativeInfinity, -1, 0, 1, kInfinity, kNaN, kNaN}}); + assertEqualVectors(expected, evaluate("try(array_sort(c0))", input)); + } + + { + auto expected = makeNullableArrayVector( + {{kNaN, kNaN, kInfinity, 1, 0, -1, kNegativeInfinity}}); + assertEqualVectors(expected, evaluate("try(array_sort_desc(c0))", input)); + } + } + // Specify the number of values per each data vector in 'dataVectorsByType_'. const int numValues_; std::unordered_map dataVectorsByType_; @@ -680,6 +703,11 @@ TEST_F(ArraySortTest, failOnRowNullCompare) { } } +TEST_F(ArraySortTest, floatingPointExtremes) { + testFloatingPoint(); + testFloatingPoint(); +} + VELOX_INSTANTIATE_TEST_SUITE_P( ArraySortTest, ArraySortTest, diff --git a/velox/functions/prestosql/tests/ArrayUnionTest.cpp b/velox/functions/prestosql/tests/ArrayUnionTest.cpp index dc57d21c9c70f..9173f3725f54c 100644 --- a/velox/functions/prestosql/tests/ArrayUnionTest.cpp +++ b/velox/functions/prestosql/tests/ArrayUnionTest.cpp @@ -32,6 +32,38 @@ class ArrayUnionTest : public FunctionBaseTest { auto result = evaluate(expression, makeRowVector(input)); assertEqualVectors(expected, result); } + + template + void floatArrayTest() { + static const T kQuietNaN = std::numeric_limits::quiet_NaN(); + static const T kSignalingNaN = std::numeric_limits::signaling_NaN(); + static const T kInfinity = std::numeric_limits::infinity(); + const auto array1 = makeArrayVector( + {{1.1, 2.2, 3.3, 4.4}, + {3.3, 4.4}, + {3.3, 4.4, kQuietNaN}, + {3.3, 4.4, kQuietNaN}, + {3.3, 4.4, 
kQuietNaN}, + {3.3, 4.4, kQuietNaN, kInfinity}}); + const auto array2 = makeArrayVector( + {{3.3, 4.4}, + {3.3, 5.5}, + {5.5}, + {3.3, kQuietNaN}, + {5.5, kSignalingNaN}, + {5.5, kInfinity}}); + VectorPtr expected; + + expected = makeArrayVector({ + {1.1, 2.2, 3.3, 4.4}, + {3.3, 4.4, 5.5}, + {3.3, 4.4, kQuietNaN, 5.5}, + {3.3, 4.4, kQuietNaN}, + {3.3, 4.4, kQuietNaN, 5.5}, + {3.3, 4.4, kQuietNaN, kInfinity, 5.5}, + }); + testExpression("array_union(c0, c1)", {array1, array2}, expected); + } }; /// Union two integer arrays. @@ -129,4 +161,11 @@ TEST_F(ArrayUnionTest, complexTypes) { testExpression( "array_union(c0, c1)", {arrayOfArrays1, arrayOfArrays2}, expected); } + +/// Union two floating point arrays including extreme values like infinity and +/// NaN. +TEST_F(ArrayUnionTest, floatingPointType) { + floatArrayTest(); + floatArrayTest(); +} } // namespace diff --git a/velox/functions/prestosql/tests/ArraysOverlapTest.cpp b/velox/functions/prestosql/tests/ArraysOverlapTest.cpp index 57eccbc6362eb..74bf36678cb0c 100644 --- a/velox/functions/prestosql/tests/ArraysOverlapTest.cpp +++ b/velox/functions/prestosql/tests/ArraysOverlapTest.cpp @@ -80,6 +80,9 @@ class ArraysOverlapTest : public FunctionBaseTest { std::numeric_limits::infinity(), std::numeric_limits::max()}, {std::numeric_limits::quiet_NaN(), 9.0009}, + {std::numeric_limits::quiet_NaN(), 3.1}, + // quiet NaN and signaling NaN are treated equal + {std::numeric_limits::signaling_NaN(), 3.1}, {std::numeric_limits::quiet_NaN(), 9.0009, std::nullopt}}); auto array2 = makeNullableArrayVector( {{1.0, -2.0, 4.0}, @@ -87,9 +90,11 @@ class ArraysOverlapTest : public FunctionBaseTest { {1.0001, -2.02, std::numeric_limits::max(), 8.00099}, {9.0009, std::numeric_limits::infinity()}, {9.0009, std::numeric_limits::quiet_NaN()}, + {9.0009, std::numeric_limits::quiet_NaN()}, + {9.0009, std::numeric_limits::quiet_NaN()}, {9.0}}); auto expected = makeNullableFlatVector( - {true, true, true, true, true, std::nullopt}); + 
{true, true, true, true, true, true, true, std::nullopt}); testExpr(expected, "arrays_overlap(C0, C1)", {array1, array2}); testExpr(expected, "arrays_overlap(C1, C0)", {array1, array2}); } diff --git a/velox/type/FloatingPointUtil.h b/velox/type/FloatingPointUtil.h index 63f100bcaa4f4..c2fa80857593f 100644 --- a/velox/type/FloatingPointUtil.h +++ b/velox/type/FloatingPointUtil.h @@ -20,6 +20,8 @@ #include #include +#include + namespace facebook::velox { /// Custom comparator and hash functors for floating point types. These are @@ -80,6 +82,30 @@ struct NaNAwareHash { return std::hash{}(val); } }; + +// Utility struct to provide a clean way of defining a hash set type using +// folly::F14FastSet with overrides for floating point types. +template +struct HashSetTypeTraits { + using Type = folly::F14FastSet; +}; + +// Floating point specializations: use NaN-aware hash and equality so that the hash set treats NaNs as equal. +template <> +struct HashSetTypeTraits { + using Type = folly::F14FastSet< + float, + util::floating_point::NaNAwareHash, + util::floating_point::NaNAwareEquals>; +}; + +template <> +struct HashSetTypeTraits { + using Type = folly::F14FastSet< + double, + util::floating_point::NaNAwareHash, + util::floating_point::NaNAwareEquals>; +}; } // namespace util::floating_point /// A static class that holds helper functions for DOUBLE type.