diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index 2a6208ce9156..05dea9101c17 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -94,30 +94,30 @@ Array Functions .. function:: array_max(array(E)) -> E Returns the maximum value of input array. - Returns NaN if E is REAL or DOUBLE and array contains a NaN value. - Returns NULL if array doesn't contain a NaN value, but contains a NULL value. :: + NaN is considered to be greater than Infinity. + Returns NULL if array contains a NULL value. :: SELECT array_max(ARRAY [1, 2, 3]); -- 3 SELECT array_max(ARRAY [-1, -2, -2]); -- -1 SELECT array_max(ARRAY [-1, -2, NULL]); -- NULL SELECT array_max(ARRAY []); -- NULL - SELECT array_max(ARRAY[NULL, nan()]); -- NaN + SELECT array_max(ARRAY [-1, nan(), NULL]); -- NULL SELECT array_max(ARRAY[{-1, -2, -3, nan()]); -- NaN - SELECT array_max(ARRAY[-0.0001, NULL, -0.0003, nan()]); -- NaN + SELECT array_max(ARRAY[infinity(), nan()]); -- NaN .. function:: array_min(array(E)) -> E Returns the minimum value of input array. - Returns NaN if E is REAL or DOUBLE and array contains a NaN value. - Returns NULL if array doesn't contain a NaN value, but contains a NULL value. :: + NaN is considered to be greater than Infinity. + Returns NULL if array contains a NULL value. :: SELECT array_min(ARRAY [1, 2, 3]); -- 1 SELECT array_min(ARRAY [-1, -2, -2]); -- -2 SELECT array_min(ARRAY [-1, -2, NULL]); -- NULL SELECT array_min(ARRAY []); -- NULL - SELECT array_min(ARRAY[NULL, nan()]); -- NaN - SELECT array_min(ARRAY[{-1, -2, -3, nan()]); -- NaN - SELECT array_min(ARRAY[-0.0001, NULL, -0.0003, nan()]); -- NaN + SELECT array_min(ARRAY [-1, nan(), NULL]); -- NULL + SELECT array_min(ARRAY[-1, -2, -3, nan()]); -- -3 + SELECT array_min(ARRAY[infinity(), nan()]); -- Infinity .. 
function:: array_normalize(array(E), E) -> array(E) diff --git a/velox/functions/prestosql/ArrayFunctions.h b/velox/functions/prestosql/ArrayFunctions.h index b87079a4408d..9bb3321a8530 100644 --- a/velox/functions/prestosql/ArrayFunctions.h +++ b/velox/functions/prestosql/ArrayFunctions.h @@ -21,6 +21,7 @@ #include "velox/functions/Udf.h" #include "velox/functions/lib/CheckedArithmetic.h" #include "velox/type/Conversions.h" +#include "velox/type/FloatingPointUtil.h" namespace facebook::velox::functions { @@ -32,13 +33,27 @@ struct ArrayMinMaxFunction { template void update(T& currentValue, const T& candidateValue) { - if constexpr (isMax) { - if (candidateValue > currentValue) { - currentValue = candidateValue; + if constexpr (std::is_same_v || std::is_same_v) { + using facebook::velox::util::floating_point::NaNAwareGreaterThan; + using facebook::velox::util::floating_point::NaNAwareLessThan; + if constexpr (isMax) { + if (NaNAwareGreaterThan{}(candidateValue, currentValue)) { + currentValue = candidateValue; + } + } else { + if (NaNAwareLessThan{}(candidateValue, currentValue)) { + currentValue = candidateValue; + } } } else { - if (candidateValue < currentValue) { - currentValue = candidateValue; + if constexpr (isMax) { + if (candidateValue > currentValue) { + currentValue = candidateValue; + } + } else { + if (candidateValue < currentValue) { + currentValue = candidateValue; + } } } } @@ -52,59 +67,6 @@ struct ArrayMinMaxFunction { out.setNoCopy(value); } - template - bool callForFloatOrDouble(TReturn& out, const TInput& array) { - bool hasNull = false; - auto it = array.begin(); - - // Find the first non-null item (if any) - while (it != array.end()) { - if (it->has_value()) { - break; - } - - hasNull = true; - ++it; - } - - // Return false if end of array is reached without finding a non-null item. - if (it == array.end()) { - return false; - } - - // If first non-null item is NAN, return immediately. 
- auto currentValue = it->value(); - if (std::isnan(currentValue)) { - assign(out, currentValue); - return true; - } - - ++it; - while (it != array.end()) { - if (it->has_value()) { - auto newValue = it->value(); - if (std::isnan(newValue)) { - assign(out, newValue); - return true; - } - update(currentValue, newValue); - } else { - hasNull = true; - } - ++it; - } - - // If we found a null, return false. Note that, if we found - // a NAN, the function will return at earlier stage as soon as - // a NAN is observed. - if (hasNull) { - return false; - } - - assign(out, currentValue); - return true; - } - template FOLLY_ALWAYS_INLINE bool call(TReturn& out, const TInput& array) { // Result is null if array is empty. @@ -112,11 +74,6 @@ struct ArrayMinMaxFunction { return false; } - if constexpr ( - std::is_same_v || std::is_same_v) { - return callForFloatOrDouble(out, array); - } - if (!array.mayHaveNulls()) { // Input array does not have nulls. auto currentValue = *array[0]; diff --git a/velox/functions/prestosql/tests/ArrayMaxTest.cpp b/velox/functions/prestosql/tests/ArrayMaxTest.cpp index a74b536f7b6e..e1c1bcfad884 100644 --- a/velox/functions/prestosql/tests/ArrayMaxTest.cpp +++ b/velox/functions/prestosql/tests/ArrayMaxTest.cpp @@ -145,24 +145,21 @@ TEST_F(ArrayMaxTest, longVarcharNoNulls) { // Test documented example. 
TEST_F(ArrayMaxTest, docs) { - auto input1 = makeNullableArrayVector( - {{1, 2, 3}, {-1, -2, -2}, {-1, -2, std::nullopt}, {}}); - auto expected1 = - makeNullableFlatVector({3, -1, std::nullopt, std::nullopt}); - testArrayMax(input1, expected1); - - auto input2 = makeNullableArrayVector( - {{std::nullopt, std::numeric_limits::quiet_NaN()}, - {-1, -2, -3, std::numeric_limits::quiet_NaN()}, - {-0.0001, - std::nullopt, - -0.0003, - std::numeric_limits::quiet_NaN()}}); - auto expected2 = makeNullableFlatVector( - {std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN()}); - testArrayMax(input2, expected2); + { + auto input = makeNullableArrayVector( + {{1, 2, 3}, {-1, -2, -2}, {-1, -2, std::nullopt}, {}}); + auto expected = + makeNullableFlatVector({3, -1, std::nullopt, std::nullopt}); + testArrayMax(input, expected); + } + { + static const float kNaN = std::numeric_limits::quiet_NaN(); + static const float kInfinity = std::numeric_limits::infinity(); + auto input = makeNullableArrayVector( + {{-1, kNaN, std::nullopt}, {-1, -2, -3, kNaN}, {kInfinity, kNaN}}); + auto expected = makeNullableFlatVector({std::nullopt, kNaN, kNaN}); + testArrayMax(input, expected); + } } template @@ -293,6 +290,22 @@ class ArrayMaxFloatingPointTest : public FunctionBaseTest { 0.0001}); testArrayMax(input, expected); } + + void testExtremeValues() { + static const T kNaN = std::numeric_limits::quiet_NaN(); + static const T kInfinity = std::numeric_limits::infinity(); + static const T kNegativeInfinity = -1 * std::numeric_limits::infinity(); + auto input = makeNullableArrayVector( + {{-1, std::nullopt, kNaN}, + {-1, std::nullopt, 2}, + {-1, 0, 2}, + {kNegativeInfinity, kNegativeInfinity}, + {-1, 2, kInfinity}, + {kInfinity, kNaN}}); + auto expected = makeNullableFlatVector( + {std::nullopt, std::nullopt, 2, kNegativeInfinity, kInfinity, kNaN}); + testArrayMax(input, expected); + } }; } // namespace @@ -318,3 +331,7 @@ 
TYPED_TEST(ArrayMaxFloatingPointTest, arrayMaxNullable) { TYPED_TEST(ArrayMaxFloatingPointTest, arrayMax) { this->testNoNulls(); } + +TYPED_TEST(ArrayMaxFloatingPointTest, arrayMaxExtreme) { + this->testExtremeValues(); +} diff --git a/velox/functions/prestosql/tests/ArrayMinTest.cpp b/velox/functions/prestosql/tests/ArrayMinTest.cpp index f3d3c9ca7c26..4b63f924ddbd 100644 --- a/velox/functions/prestosql/tests/ArrayMinTest.cpp +++ b/velox/functions/prestosql/tests/ArrayMinTest.cpp @@ -43,17 +43,12 @@ class ArrayMinTest : public FunctionBaseTest { makeNullableFlatVector({1, -2, std::nullopt, std::nullopt}); testExpr(expected1, "array_min(C0)", {input1}); + static const float kNaN = std::numeric_limits::quiet_NaN(); + static const float kInfinity = std::numeric_limits::infinity(); auto input2 = makeNullableArrayVector( - {{std::nullopt, std::numeric_limits::quiet_NaN()}, - {-1, -2, -3, std::numeric_limits::quiet_NaN()}, - {-0.0001, - std::nullopt, - -0.0003, - std::numeric_limits::quiet_NaN()}}); - auto expected2 = makeNullableFlatVector( - {std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN()}); + {{-1, kNaN, std::nullopt}, {-1, -2, -3, kNaN}, {kInfinity, kNaN}}); + auto expected2 = + makeNullableFlatVector({std::nullopt, -3, kInfinity}); testExpr(expected2, "array_min(C0)", {input2}); } @@ -157,6 +152,24 @@ class ArrayMinTest : public FunctionBaseTest { {false, true, false, std::nullopt, false, true}); testExpr(expected, "array_min(C0)", {arrayVector}); } + + template + void testFloatingPoint() { + static const T kNaN = std::numeric_limits::quiet_NaN(); + static const T kInfinity = std::numeric_limits::infinity(); + static const T kNegativeInfinity = -1 * std::numeric_limits::infinity(); + auto input = makeNullableArrayVector( + {{-1, std::nullopt, kNaN}, + {-1, std::nullopt, 2}, + {-1, 0, 2}, + {-1, kNegativeInfinity, kNaN}, + {kInfinity, kNaN}, + {kNaN, kNaN}}); + auto expected = makeNullableFlatVector( + 
{std::nullopt, std::nullopt, -1, kNegativeInfinity, kInfinity, kNaN}); + + testExpr(expected, "array_min(C0)", {input}); + } }; } // namespace @@ -185,6 +198,11 @@ TEST_F(ArrayMinTest, boolArrays) { testBool(); } +TEST_F(ArrayMinTest, floatArrays) { + testFloatingPoint(); + testFloatingPoint(); +} + TEST_F(ArrayMinTest, complexTypeElements) { auto elements = makeRowVector({ makeFlatVector({1, 2, 3, 3, 2, 1}),