Skip to content

Commit

Permalink
Tighten checks for "bins" elements in width_bucket(x, bins) (#11629)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #11629

Make width_bucket(x, bins) throw error if it finds a null or
non-finite element in bins.
As per new Presto Java behavior: prestodb/presto#24103

Reviewed By: Yuhta

Differential Revision: D66382264

fbshipit-source-id: 84f549dfa313e1794551fba9d71f8b87eb3b713e
  • Loading branch information
Sergey Pershin authored and facebook-github-bot committed Nov 23, 2024
1 parent 059337f commit bf3fba7
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 41 deletions.
8 changes: 6 additions & 2 deletions velox/docs/functions/presto/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -192,12 +192,16 @@ Mathematical Functions
:noindex:

Returns the zero-based bin number of ``x`` according to the bins specified
by the array ``bins``. The ``bins`` parameter must be an array of doubles and
is assumed to be in sorted ascending order.
by the array ``bins``. The ``bins`` parameter must be an array of doubles, should not
contain ``null`` or non-finite elements, and is assumed to be in sorted ascending order.

For example, if ``bins`` is ``ARRAY[0, 2, 4]``, then we have four bins:
``(-infinity(), 0)``, ``[0, 2)``, ``[2, 4)`` and ``[4, infinity())``.

Note: The function returns an error if it encounters a ``null`` or non-finite
element in ``bins``, but due to the binary search algorithm some such elements
might go unnoticed and the function will return a result.


====================================
Trigonometric Functions
Expand Down
27 changes: 17 additions & 10 deletions velox/functions/prestosql/WidthBucketArray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,22 @@ int64_t widthBucket(
int lower = 0;
int upper = binCount;
while (lower < upper) {
VELOX_USER_CHECK_LE(
elementsHolder.valueAt<T>(offset + lower),
elementsHolder.valueAt<T>(offset + upper - 1),
"Bin values are not sorted in ascending order");

int index = (lower + upper) / 2;
auto bin = elementsHolder.valueAt<T>(offset + index);
const int index = (lower + upper) / 2;
VELOX_USER_CHECK(
!elementsHolder.isNullAt(lower) && !elementsHolder.isNullAt(index) &&
!elementsHolder.isNullAt(upper - 1),
"Bin values cannot be NULL");

VELOX_USER_CHECK(std::isfinite(bin), "Bin value must be finite");
const auto bin = elementsHolder.valueAt<T>(offset + index);
const auto lowerBin = elementsHolder.valueAt<T>(offset + lower);
const auto upperBin = elementsHolder.valueAt<T>(offset + upper - 1);
VELOX_USER_CHECK(
lowerBin <= bin && bin <= upperBin,
"Bin values are not sorted in ascending order");
VELOX_USER_CHECK(
std::isfinite(bin) && std::isfinite(lowerBin) &&
std::isfinite(upperBin),
"Bin values must be finite");

if (operand < bin) {
upper = index;
Expand Down Expand Up @@ -162,9 +169,9 @@ std::vector<double> toBinValues(

for (int i = 0; i < size; i++) {
VELOX_USER_CHECK(
!simpleVector->isNullAt(offset + i), "Bin value cannot be null");
!simpleVector->isNullAt(offset + i), "Bin values cannot be null");
auto value = simpleVector->valueAt(offset + i);
VELOX_USER_CHECK(std::isfinite(value), "Bin value must be finite");
VELOX_USER_CHECK(std::isfinite(value), "Bin values must be finite");
if (i > 0) {
VELOX_USER_CHECK_GT(
value,
Expand Down
85 changes: 56 additions & 29 deletions velox/functions/prestosql/tests/WidthBucketArrayTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,39 +61,51 @@ TEST_F(WidthBucketArrayTest, success) {
assertEqualVectors(dictExpected, dictResult);
};

{
binsVector = makeArrayVector<double>({{0.0, 2.0, 4.0}, {0.0}});
testWidthBucketArray(3.14, {2, 1});
testWidthBucketArray(kInf, {3, 1});
testWidthBucketArray(-1, {0, 0});
}

{
binsVector = makeArrayVector<int64_t>({{0, 2, 4}, {0}});
testWidthBucketArray(3.14, {2, 1});
testWidthBucketArray(kInf, {3, 1});
testWidthBucketArray(-1, {0, 0});
}
binsVector = makeArrayVector<double>({{0.0, 2.0, 4.0}, {0.0}});
testWidthBucketArray(3.14, {2, 1});
testWidthBucketArray(kInf, {3, 1});
testWidthBucketArray(-1, {0, 0});

binsVector = makeArrayVector<int64_t>({{0, 2, 4}, {0}});
testWidthBucketArray(3.14, {2, 1});
testWidthBucketArray(kInf, {3, 1});
testWidthBucketArray(-1, {0, 0});

// Cases we cannot catch due to the binary search algorithm.
binsVector = makeNullableArrayVector<double>(
{{0.0, std::nullopt, 2.0, 4.0},
{0.0, std::nullopt, 1.0, 2.0, 4.0},
{0.0, kInf, 1.0, 2.0, 4.0}});
testWidthBucketArray(3.14, {3, 4, 4});
}

TEST_F(WidthBucketArrayTest, failure) {
auto testFailure = [&](const double operand,
const std::vector<std::vector<double>>& bins,
const std::string& expected_message) {
auto binsVector = makeArrayVector<double>(bins);
VELOX_ASSERT_THROW(
evaluate<SimpleVector<int64_t>>(
"width_bucket(c0, c1)",
makeRowVector(
{makeConstant(operand, binsVector->size()), binsVector})),
expected_message);
};
auto testFailure =
[&](const double operand,
const std::vector<std::vector<std::optional<double>>>& bins,
const std::string& expected_message) {
auto binsVector = makeNullableArrayVector<double>(bins);
VELOX_ASSERT_THROW(
evaluate<SimpleVector<int64_t>>(
"width_bucket(c0, c1)",
makeRowVector(
{makeConstant(operand, binsVector->size()), binsVector})),
expected_message);
};

testFailure(0, {{}}, "Bins cannot be an empty array");
testFailure(kNan, {{0}}, "Operand cannot be NaN");
testFailure(1, {{0, kInf}}, "Bin value must be finite");
testFailure(1, {{0, kInf}}, "Bin values must be finite");
testFailure(1, {{0, kNan}}, "Bin values are not sorted in ascending order");
testFailure(2, {{1, 0}}, "Bin values are not sorted in ascending order");
testFailure(
3.14, {{0, kInf, 10}}, "Bin values are not sorted in ascending order");
testFailure(
1.5, {{1.0, 2, 3, 2, 0}}, "Bin values are not sorted in ascending order");
testFailure(3.14, {{std::nullopt}}, "Bin values cannot be NULL");
testFailure(3.14, {{0.0, std::nullopt, 4.0}}, "Bin values cannot be NULL");
testFailure(
3.14, {{0.0, 2.0, 4.0, std::nullopt}}, "Bin values cannot be NULL");
}

TEST_F(WidthBucketArrayTest, successForConstantArray) {
Expand All @@ -120,9 +132,16 @@ TEST_F(WidthBucketArrayTest, successForConstantArray) {
testWidthBucketArray(3.14, "ARRAY[0.0]", 1);
testWidthBucketArray(kInf, "ARRAY[0.0]", 1);
testWidthBucketArray(-1, "ARRAY[0.0]", 0);

// Cases we cannot catch due to the binary search algorithm.
// If the 'bins' vector has issues we simply fall back to the non-constant
// case.
testWidthBucketArray(3.14, "ARRAY[0.0, NULL, 2.0, 4.0]", 3);
testWidthBucketArray(3.14, "ARRAY[0.0, NULL, 1.0, 2.0, 4.0]", 4);
testWidthBucketArray(3.14, "ARRAY[0.0, Infinity(), 1.0, 2.0, 4.0]", 4);
}

TEST_F(WidthBucketArrayTest, failureForConstant) {
TEST_F(WidthBucketArrayTest, failureForConstantArray) {
auto testFailure = [&](const double operand,
const std::string& bins,
const std::string& expected_message) {
Expand All @@ -133,12 +152,20 @@ TEST_F(WidthBucketArrayTest, failureForConstant) {
expected_message);
};

// TODO: Add tests for empty bin and bins that contains infinity(), nan()
// once corresponding casting and non-constant array literal element is
// supported.
testFailure(kNan, "ARRAY[0.0]", "Operand cannot be NaN");
testFailure(
2, "ARRAY[1.0, 0.0]", "Bin values are not sorted in ascending order");
testFailure(
3.14,
"ARRAY[0.0, Infinity(), 10.0]",
"Bin values are not sorted in ascending order");
testFailure(
1.5,
"ARRAY[1.0, 2.0, 3.0, 2.0, 0.0]",
"Bin values are not sorted in ascending order");
testFailure(3.14, "ARRAY[cast(NULL as double)]", "Bin values cannot be NULL");
testFailure(3.14, "ARRAY[0.0, NULL, 4.0]", "Bin values cannot be NULL");
testFailure(3.14, "ARRAY[0.0, 2.0, 4.0, NULL]", "Bin values cannot be NULL");
}

TEST_F(WidthBucketArrayTest, makeWidthBucketArrayNoThrow) {
Expand Down

0 comments on commit bf3fba7

Please sign in to comment.