Skip to content

Commit

Permalink
Fix NaN handling for in-predicate (facebookincubator#10115)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookincubator#10115

Ensure NaN values of different binary representation for floating
point types are considered as equal.

Summary of changes:
- Primitive type Input: NaN of different binary representations are
denormalized to the same representation before adding to the
in-list and before being compared.
- Complex Type input: Uses a set that employs hash and equality
functions via BaseVector that have been fixed in facebookincubator#9963 to handle
NaN values.

Reviewed By: kagamiori

Differential Revision: D58301120
  • Loading branch information
Bikramjeet Vig authored and facebook-github-bot committed Jun 21, 2024
1 parent c97e7fc commit 4010a84
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 18 deletions.
35 changes: 17 additions & 18 deletions velox/functions/prestosql/InPredicate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,15 @@ createFloatingPointValuesFilter(
VELOX_USER_CHECK(
!values.empty(),
"IN predicate expects at least one non-null value in the in-list");

if (values.size() == 1) {
return {
std::make_unique<common::FloatingPointRange<T>>(
values[0], false, false, values[0], false, false, nullAllowed),
false};
}

// Avoid using FloatingPointRange for optimization of a single value in-list
// as it does not support NaN as a bound for specifying a range.
std::vector<int64_t> intValues(values.size());
for (size_t i = 0; i < values.size(); ++i) {
if (std::isnan(values[i])) {
// We de-normalize NaN values to ensure different binary representations
// are treated the same.
values[i] = std::numeric_limits<T>::quiet_NaN();
}
if constexpr (std::is_same_v<T, float>) {
if (values[i] == float{}) {
values[i] = 0;
Expand Down Expand Up @@ -411,26 +410,26 @@ class InPredicate : public exec::VectorFunction {
break;
case TypeKind::REAL:
applyTyped<float>(rows, input, context, result, [&](float value) {
auto* derived =
dynamic_cast<common::FloatingPointRange<float>*>(filter_.get());
if (derived) {
return filter_->testFloat(value);
}
if (value == float{}) {
value = 0;
} else if (std::isnan(value)) {
// We de-normalize NaN values to ensure different binary
// representations
// are treated the same.
value = std::numeric_limits<float>::quiet_NaN();
}
return filter_->testInt64(reinterpret_cast<const int32_t&>(value));
});
break;
case TypeKind::DOUBLE:
applyTyped<double>(rows, input, context, result, [&](double value) {
auto* derived =
dynamic_cast<common::FloatingPointRange<double>*>(filter_.get());
if (derived) {
return filter_->testDouble(value);
}
if (value == double{}) {
value = 0;
} else if (std::isnan(value)) {
// We de-normalize NaN values to ensure different binary
// representations
// are treated the same.
value = std::numeric_limits<double>::quiet_NaN();
}
return filter_->testInt64(reinterpret_cast<const int64_t&>(value));
});
Expand Down
114 changes: 114 additions & 0 deletions velox/functions/prestosql/tests/InPredicateTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,114 @@ class InPredicateTest : public FunctionBaseTest {

return makeFlatVector(timestamps);
}

template <typename T>
void testNaNs() {
const T kNaN = std::numeric_limits<T>::quiet_NaN();
const T kSNaN = std::numeric_limits<T>::signaling_NaN();
TypePtr columnFloatType = CppToType<T>::create();

// Constant In-list, primitive input.
auto testInWithConstList = [&](std::vector<T> input,
std::vector<T> inlist,
std::vector<bool> expected) {
auto expr = std::make_shared<core::CallTypedExpr>(
BOOLEAN(),
std::vector<core::TypedExprPtr>{
field(columnFloatType, "c0"),
std::make_shared<core::ConstantTypedExpr>(
makeArrayVector<T>({inlist})),
},
"in");
auto data = makeRowVector({
makeFlatVector<T>(input),
});
auto expectedResults = makeFlatVector<bool>(expected);
auto result = evaluate(expr, data);
assertEqualVectors(expectedResults, result);
};

testInWithConstList({kNaN, kSNaN}, {kNaN, 1}, {true, true});
testInWithConstList({kNaN, kSNaN}, {1, 2}, {false, false});
// Need to specifically test in-list with a single element as it previously
// had a seperate codepath.
testInWithConstList({kNaN, kSNaN}, {kNaN}, {true, true});
testInWithConstList({kNaN, kSNaN}, {1}, {false, false});

{
// Constant In-list, complex input(row).
// In-list is [row{kNaN, 1}].
auto inlist = makeArrayVector(
{0},
makeRowVector(
{makeFlatVector<T>(std::vector<T>({kNaN})),
makeFlatVector<int32_t>(std::vector<int32_t>({1}))}));
auto expr = std::make_shared<core::CallTypedExpr>(
BOOLEAN(),
std::vector<core::TypedExprPtr>{
field(ROW({columnFloatType, INTEGER()}), "c0"),
std::make_shared<core::ConstantTypedExpr>(inlist),
},
"in");
// Input is [row{kNaN, 1}, row{kSNaN, 1}, row{kNaN, 2}].
auto data = makeRowVector({makeRowVector(
{makeFlatVector<T>(std::vector<T>({kNaN, kSNaN, kNaN})),
makeFlatVector<int32_t>(std::vector<int32_t>({1, 1, 2}))})});
auto expectedResults = makeFlatVector<bool>({true, true, false});
auto result = evaluate(expr, data);
assertEqualVectors(expectedResults, result);
}

{
// Variable In-list, primitive input.
auto data = makeRowVector({
makeFlatVector<T>({kNaN, kSNaN, kNaN}),
makeFlatVector<T>({kNaN, kNaN, 0}),
makeFlatVector<T>({1, 1, 1}),
});
// Expression: c0 in (c1, c2)
auto inWithVariableInList = std::make_shared<core::CallTypedExpr>(
BOOLEAN(),
std::vector<core::TypedExprPtr>{
field(columnFloatType, "c0"),
field(columnFloatType, "c1"),
field(columnFloatType, "c2"),
},
"in");
auto expectedResults = makeFlatVector<bool>({
true, // kNaN in (kNaN, 1)
true, // kSNaN in (kNaN, 1)
false, // kNaN in (kNaN, 0)
});
auto result = evaluate(inWithVariableInList, data);
assertEqualVectors(expectedResults, result);
}

{
// Variable In-list, complex input(row).
// Input is:
// c0: [row{kNaN, 1}, row{kSNaN, 1}, row{kNaN, 2}]
// c1: [row{kNaN, 1}, row{kNaN, 1}, row{kNaN, 1}]
auto data = makeRowVector(
{makeRowVector(
{makeFlatVector<T>(std::vector<T>({kNaN, kSNaN, kNaN})),
makeFlatVector<int32_t>(std::vector<int32_t>({1, 1, 2}))}),
makeRowVector(
{makeFlatVector<T>(std::vector<T>({kNaN, kNaN, kNaN})),
makeFlatVector<int32_t>(std::vector<int32_t>({1, 1, 1}))})});
// Expression: c0 in (c1)
auto inWithVariableInList = std::make_shared<core::CallTypedExpr>(
BOOLEAN(),
std::vector<core::TypedExprPtr>{
field(ROW({columnFloatType, INTEGER()}), "c0"),
field(ROW({columnFloatType, INTEGER()}), "c1"),
},
"in");
auto expectedResults = makeFlatVector<bool>({true, true, false});
auto result = evaluate(inWithVariableInList, data);
assertEqualVectors(expectedResults, result);
}
}
};

TEST_F(InPredicateTest, bigint) {
Expand Down Expand Up @@ -952,5 +1060,11 @@ TEST_F(InPredicateTest, nonConstantInList) {
assertEqualVectors(expected, result);
}

TEST_F(InPredicateTest, nans) {
// Ensure that NaNs with different bit patterns are treated as equal.
testNaNs<float>();
testNaNs<double>();
}

} // namespace
} // namespace facebook::velox::functions

0 comments on commit 4010a84

Please sign in to comment.