Skip to content

Commit

Permalink
Fix NaN handling for multiple array UDFs (facebookincubator#9797)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#9797

Fixes for the following:
array_sort
array_distinct
array_intersect
array_except
array_overlap
array_union

Also updated documentation.

Differential Revision: D57305880
  • Loading branch information
bikramSingh91 authored and facebook-github-bot committed May 16, 2024
1 parent a2e3d86 commit c6be425
Show file tree
Hide file tree
Showing 12 changed files with 173 additions and 32 deletions.
16 changes: 13 additions & 3 deletions velox/docs/functions/presto/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ Array Functions

.. function:: array_distinct(array(E)) -> array(E)

Remove duplicate values from the input array. ::
Remove duplicate values from the input array.
For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. ::

SELECT array_distinct(ARRAY [1, 2, 3]); -- [1, 2, 3]
SELECT array_distinct(ARRAY [1, 2, 1]); -- [1, 2]
Expand All @@ -50,7 +51,8 @@ Array Functions

.. function:: array_except(array(E) x, array(E) y) -> array(E)

Returns an array of the elements in array ``x`` but not in array ``y``, without duplicates. ::
Returns an array of the elements in array ``x`` but not in array ``y``, without duplicates.
For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. ::

SELECT array_except(ARRAY [1, 2, 3], ARRAY [4, 5, 6]); -- [1, 2, 3]
SELECT array_except(ARRAY [1, 2, 3], ARRAY [1, 2]); -- [3]
Expand All @@ -77,7 +79,8 @@ Array Functions

.. function:: array_intersect(array(E) x, array(E) y) -> array(E)

Returns an array of the elements in the intersection of array ``x`` and array ``y``, without duplicates. ::
Returns an array of the elements in the intersection of array ``x`` and array ``y``, without duplicates.
For REAL and DOUBLE, NANs (Not-a-Number) are considered equal. ::

SELECT array_intersect(ARRAY [1, 2, 3], ARRAY[4, 5, 6]); -- []
SELECT array_intersect(ARRAY [1, 2, 2], ARRAY[1, 1, 2]); -- [1, 2]
Expand Down Expand Up @@ -127,6 +130,12 @@ Array Functions

Tests if arrays ``x`` and ``y`` have any non-null elements in common.
Returns null if there are no non-null elements in common but either array contains null.
For REAL and DOUBLE, NANs (Not-a-Number) are considered equal.

.. function:: arrays_union(x, y) -> array

Returns an array of the elements in the union of x and y, without duplicates.
For REAL and DOUBLE, NANs (Not-a-Number) are considered equal.

.. function:: array_position(x, element) -> bigint

Expand All @@ -153,6 +162,7 @@ Array Functions

SELECT array_sort(ARRAY [1, 2, 3]); -- [1, 2, 3]
SELECT array_sort(ARRAY [3, 2, 1]); -- [1, 2, 3]
SELECT array_sort(ARRAY [infinity(), -1.1, nan(), 1.1, -Infinity(), 0])); -- [-Infinity, -1.1, 0, 1.1, Infinity, NaN]
SELECT array_sort(ARRAY [2, 1, NULL]; -- [1, 2, NULL]
SELECT array_sort(ARRAY [NULL, 1, NULL]); -- [1, NULL, NULL]
SELECT array_sort(ARRAY [NULL, 2, 1]); -- [1, 2, NULL]
Expand Down
6 changes: 4 additions & 2 deletions velox/functions/prestosql/ArrayDistinct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
* limitations under the License.
*/

#include <folly/container/F14Set.h>

#include "velox/expression/EvalCtx.h"
#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/type/FloatingPointUtil.h"

namespace facebook::velox::functions {
namespace {
Expand Down Expand Up @@ -96,7 +96,9 @@ class ArrayDistinctFunction : public exec::VectorFunction {
auto* rawNewOffsets = newOffsets->asMutable<vector_size_t>();

// Process the rows: store unique values in the hash table.
folly::F14FastSet<T> uniqueSet;
using HashSetType =
typename util::floating_point::HashSetTypeTraits<T>::Type;
HashSetType uniqueSet;

rows.applyToSelected([&](vector_size_t row) {
auto size = arrayVector->sizeAt(row);
Expand Down
12 changes: 7 additions & 5 deletions velox/functions/prestosql/ArrayFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ struct ArrayMinMaxFunction {
template <typename T>
void update(T& currentValue, const T& candidateValue) {
if constexpr (std::is_same_v<T, float> || std::is_same_v<T, double>) {
using facebook::velox::util::floating_point::NaNAwareGreaterThan;
using facebook::velox::util::floating_point::NaNAwareLessThan;
if constexpr (isMax) {
if (NaNAwareGreaterThan<T>{}(candidateValue, currentValue)) {
if (util::floating_point::NaNAwareGreaterThan<T>{}(
candidateValue, currentValue)) {
currentValue = candidateValue;
}
} else {
if (NaNAwareLessThan<T>{}(candidateValue, currentValue)) {
if (util::floating_point::NaNAwareLessThan<T>{}(
candidateValue, currentValue)) {
currentValue = candidateValue;
}
}
Expand Down Expand Up @@ -836,7 +836,9 @@ struct ArrayUnionFunction {

template <typename Out, typename In>
void call(Out& out, const In& inputArray1, const In& inputArray2) {
folly::F14FastSet<typename In::element_t> elementSet;
using HashSetType = typename util::floating_point::HashSetTypeTraits<
typename In::element_t>::Type;
HashSetType elementSet;
bool nullAdded = false;
auto addItems = [&](auto& inputArray) {
for (const auto& item : inputArray) {
Expand Down
4 changes: 3 additions & 1 deletion velox/functions/prestosql/ArrayIntersectExcept.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "velox/expression/VectorFunction.h"
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/type/FloatingPointUtil.h"

namespace facebook::velox::functions {
namespace {
Expand All @@ -31,7 +32,8 @@ struct SetWithNull {
hasNull = false;
}

folly::F14FastSet<T> set;
using HashSetType = typename util::floating_point::HashSetTypeTraits<T>::Type;
HashSetType set;
bool hasNull{false};
static constexpr vector_size_t kInitialSetSize{128};
};
Expand Down
14 changes: 14 additions & 0 deletions velox/functions/prestosql/ArraySort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "velox/functions/lib/LambdaFunctionUtil.h"
#include "velox/functions/lib/RowsTranslationUtil.h"
#include "velox/functions/prestosql/SimpleComparisonMatcher.h"
#include "velox/type/FloatingPointUtil.h"

namespace facebook::velox::functions {
namespace {
Expand Down Expand Up @@ -172,6 +173,19 @@ void applyScalarType(
bits::fillBits(rawBits, startRow, startRow + numOneBits, true);
bits::fillBits(rawBits, endZeroRow, endRow, false);
}
} else if constexpr (kind == TypeKind::REAL || kind == TypeKind::DOUBLE) {
T* resultRawValues = flatResults->mutableRawValues();
if (ascending) {
std::sort(
resultRawValues + startRow,
resultRawValues + endRow,
util::floating_point::NaNAwareLessThan<T>());
} else {
std::sort(
resultRawValues + startRow,
resultRawValues + endRow,
util::floating_point::NaNAwareGreaterThan<T>());
}
} else {
T* resultRawValues = flatResults->mutableRawValues();
if (ascending) {
Expand Down
10 changes: 6 additions & 4 deletions velox/functions/prestosql/tests/ArrayDistinctTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class ArrayDistinctTest : public FunctionBaseTest {
{0.0, -10.0},
{std::numeric_limits<T>::quiet_NaN(),
std::numeric_limits<T>::quiet_NaN()},
{std::numeric_limits<T>::quiet_NaN(),
std::numeric_limits<T>::signaling_NaN()},
{std::numeric_limits<T>::signaling_NaN(),
std::numeric_limits<T>::signaling_NaN()},
{std::numeric_limits<T>::lowest(), std::numeric_limits<T>::lowest()},
Expand Down Expand Up @@ -134,10 +136,10 @@ class ArrayDistinctTest : public FunctionBaseTest {
{0.0},
{0.0, 10.0},
{0.0, -10.0},
{std::numeric_limits<T>::quiet_NaN(),
std::numeric_limits<T>::quiet_NaN()},
{std::numeric_limits<T>::signaling_NaN(),
std::numeric_limits<T>::signaling_NaN()},
{std::numeric_limits<T>::quiet_NaN()},
// quiet NaN and signaling NaN are treated equal
{std::numeric_limits<T>::quiet_NaN()},
{std::numeric_limits<T>::signaling_NaN()},
{std::numeric_limits<T>::lowest()},
{std::nullopt},
{1.0001, -2.0, 3.03, std::nullopt, 4.00004},
Expand Down
35 changes: 21 additions & 14 deletions velox/functions/prestosql/tests/ArrayExceptTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,31 +107,38 @@ class ArrayExceptTest : public FunctionBaseTest {
std::numeric_limits<T>::infinity(),
std::numeric_limits<T>::max()},
{std::numeric_limits<T>::quiet_NaN(), 9.0009},
{std::numeric_limits<T>::quiet_NaN(), 9.0009},
{std::numeric_limits<T>::quiet_NaN(), 9.0009},
});
auto array2 = makeNullableArrayVector<T>({
{1.0, -2.0, 4.0},
{std::numeric_limits<T>::min(), 2.0199, -2.001, 1.000001},
{1.0001, -2.02, std::numeric_limits<T>::max(), 8.00099},
{9.0009, std::numeric_limits<T>::infinity()},
{9.0009, std::numeric_limits<T>::quiet_NaN()},
});

auto expected = makeNullableArrayVector<T>({
{1.0001, 3.03, std::nullopt, 4.00004},
{2.02, 1},
{8.0001, std::nullopt},
{std::numeric_limits<T>::max()},
{std::numeric_limits<T>::quiet_NaN()},
// quiet NaN and signaling NaN are treated equal
{std::numeric_limits<T>::signaling_NaN()},
});

auto expected = makeNullableArrayVector<T>(
{{1.0001, 3.03, std::nullopt, 4.00004},
{2.02, 1},
{8.0001, std::nullopt},
{std::numeric_limits<T>::max()},
{},
{9.0009},
{9.0009}});
testExpr(expected, "array_except(C0, C1)", {array1, array2});

expected = makeNullableArrayVector<T>({
{1.0, 4.0},
{2.0199, 1.000001},
{1.0001, -2.02, 8.00099},
{},
{std::numeric_limits<T>::quiet_NaN()},
});
expected = makeNullableArrayVector<T>(
{{1.0, 4.0},
{2.0199, 1.000001},
{1.0001, -2.02, 8.00099},
{},
{},
{},
{}});
testExpr(expected, "array_except(C1, C0)", {array1, array2});
}
};
Expand Down
8 changes: 6 additions & 2 deletions velox/functions/prestosql/tests/ArrayIntersectTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,24 @@ class ArrayIntersectTest : public FunctionBaseTest {
std::numeric_limits<T>::infinity(),
std::numeric_limits<T>::max()},
{std::numeric_limits<T>::quiet_NaN(), 9.0009},
{std::numeric_limits<T>::quiet_NaN(), 9.0009},
});
auto array2 = makeNullableArrayVector<T>({
{1.0, -2.0, 4.0},
{std::numeric_limits<T>::min(), 2.0199, -2.001, 1.000001},
{1.0001, -2.02, std::numeric_limits<T>::max(), 8.00099},
{9.0009, std::numeric_limits<T>::infinity()},
{9.0009, std::numeric_limits<T>::quiet_NaN()},
{std::numeric_limits<T>::quiet_NaN()},
// quiet NaN and signaling NaN are treated equal
{std::numeric_limits<T>::signaling_NaN(), 9.0009},
});
auto expected = makeNullableArrayVector<T>({
{-2.0},
{std::numeric_limits<T>::min(), -2.001},
{std::numeric_limits<T>::max()},
{9.0009, std::numeric_limits<T>::infinity()},
{9.0009},
{std::numeric_limits<T>::quiet_NaN()},
{std::numeric_limits<T>::quiet_NaN(), 9.0009},
});

testExpr(expected, "array_intersect(C0, C1)", {array1, array2});
Expand Down
28 changes: 28 additions & 0 deletions velox/functions/prestosql/tests/ArraySortTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,29 @@ class ArraySortTest : public FunctionBaseTest,
}
}

template <typename T>
void testFloatingPoint() {
// Verify that NaNs are treated as greater than infinity
static const T kNaN = std::numeric_limits<T>::quiet_NaN();
static const T kInfinity = std::numeric_limits<T>::infinity();
static const T kNegativeInfinity = -1 * std::numeric_limits<T>::infinity();

auto input = makeRowVector({makeNullableArrayVector<T>(
{{kInfinity, -1, kNaN, 1, kNegativeInfinity, kNaN, 0}})});

{
auto expected = makeNullableArrayVector<T>(
{{kNegativeInfinity, -1, 0, 1, kInfinity, kNaN, kNaN}});
assertEqualVectors(expected, evaluate("try(array_sort(c0))", input));
}

{
auto expected = makeNullableArrayVector<T>(
{{kNaN, kNaN, kInfinity, 1, 0, -1, kNegativeInfinity}});
assertEqualVectors(expected, evaluate("try(array_sort_desc(c0))", input));
}
}

// Specify the number of values per each data vector in 'dataVectorsByType_'.
const int numValues_;
std::unordered_map<TypeKind, VectorPtr> dataVectorsByType_;
Expand Down Expand Up @@ -680,6 +703,11 @@ TEST_F(ArraySortTest, failOnRowNullCompare) {
}
}

TEST_F(ArraySortTest, floatingPointExtremes) {
testFloatingPoint<float>();
testFloatingPoint<double>();
}

VELOX_INSTANTIATE_TEST_SUITE_P(
ArraySortTest,
ArraySortTest,
Expand Down
39 changes: 39 additions & 0 deletions velox/functions/prestosql/tests/ArrayUnionTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,38 @@ class ArrayUnionTest : public FunctionBaseTest {
auto result = evaluate(expression, makeRowVector(input));
assertEqualVectors(expected, result);
}

template <typename T>
void floatArrayTest() {
static const T kQuietNaN = std::numeric_limits<T>::quiet_NaN();
static const T kSignalingNaN = std::numeric_limits<T>::signaling_NaN();
static const T kInfinity = std::numeric_limits<T>::infinity();
const auto array1 = makeArrayVector<T>(
{{1.1, 2.2, 3.3, 4.4},
{3.3, 4.4},
{3.3, 4.4, kQuietNaN},
{3.3, 4.4, kQuietNaN},
{3.3, 4.4, kQuietNaN},
{3.3, 4.4, kQuietNaN, kInfinity}});
const auto array2 = makeArrayVector<T>(
{{3.3, 4.4},
{3.3, 5.5},
{5.5},
{3.3, kQuietNaN},
{5.5, kSignalingNaN},
{5.5, kInfinity}});
VectorPtr expected;

expected = makeArrayVector<T>({
{1.1, 2.2, 3.3, 4.4},
{3.3, 4.4, 5.5},
{3.3, 4.4, kQuietNaN, 5.5},
{3.3, 4.4, kQuietNaN},
{3.3, 4.4, kQuietNaN, 5.5},
{3.3, 4.4, kQuietNaN, kInfinity, 5.5},
});
testExpression("array_union(c0, c1)", {array1, array2}, expected);
}
};

/// Union two integer arrays.
Expand Down Expand Up @@ -129,4 +161,11 @@ TEST_F(ArrayUnionTest, complexTypes) {
testExpression(
"array_union(c0, c1)", {arrayOfArrays1, arrayOfArrays2}, expected);
}

/// Union two floating point arrays including extreme values like infinity and
/// NaN.
TEST_F(ArrayUnionTest, floatingPointType) {
floatArrayTest<float>();
floatArrayTest<double>();
}
} // namespace
7 changes: 6 additions & 1 deletion velox/functions/prestosql/tests/ArraysOverlapTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,21 @@ class ArraysOverlapTest : public FunctionBaseTest {
std::numeric_limits<T>::infinity(),
std::numeric_limits<T>::max()},
{std::numeric_limits<T>::quiet_NaN(), 9.0009},
{std::numeric_limits<T>::quiet_NaN(), 3.1},
// quiet NaN and signaling NaN are treated equal
{std::numeric_limits<T>::signaling_NaN(), 3.1},
{std::numeric_limits<T>::quiet_NaN(), 9.0009, std::nullopt}});
auto array2 = makeNullableArrayVector<T>(
{{1.0, -2.0, 4.0},
{std::numeric_limits<T>::min(), 2.0199, 1.000001},
{1.0001, -2.02, std::numeric_limits<T>::max(), 8.00099},
{9.0009, std::numeric_limits<T>::infinity()},
{9.0009, std::numeric_limits<T>::quiet_NaN()},
{9.0009, std::numeric_limits<T>::quiet_NaN()},
{9.0009, std::numeric_limits<T>::quiet_NaN()},
{9.0}});
auto expected = makeNullableFlatVector<bool>(
{true, true, true, true, true, std::nullopt});
{true, true, true, true, true, true, true, std::nullopt});
testExpr(expected, "arrays_overlap(C0, C1)", {array1, array2});
testExpr(expected, "arrays_overlap(C1, C0)", {array1, array2});
}
Expand Down
Loading

0 comments on commit c6be425

Please sign in to comment.