diff --git a/velox/functions/prestosql/InPredicate.cpp b/velox/functions/prestosql/InPredicate.cpp index 769ac3c30933..78190a2302c8 100644 --- a/velox/functions/prestosql/InPredicate.cpp +++ b/velox/functions/prestosql/InPredicate.cpp @@ -191,16 +191,15 @@ createFloatingPointValuesFilter( VELOX_USER_CHECK( !values.empty(), "IN predicate expects at least one non-null value in the in-list"); - - if (values.size() == 1) { - return { - std::make_unique>( - values[0], false, false, values[0], false, false, nullAllowed), - false}; - } - + // Avoid using FloatingPointRange for optimization of a single value in-list + // as it does not support NaN as a bound for specifying a range. std::vector intValues(values.size()); for (size_t i = 0; i < values.size(); ++i) { + if (std::isnan(values[i])) { + // We normalize NaN values to ensure different binary representations + // are treated the same. + values[i] = std::numeric_limits::quiet_NaN(); + } if constexpr (std::is_same_v) { if (values[i] == float{}) { values[i] = 0; @@ -411,26 +410,26 @@ class InPredicate : public exec::VectorFunction { break; case TypeKind::REAL: applyTyped(rows, input, context, result, [&](float value) { - auto* derived = - dynamic_cast*>(filter_.get()); - if (derived) { - return filter_->testFloat(value); - } if (value == float{}) { value = 0; + } else if (std::isnan(value)) { + // We normalize NaN values to ensure different binary + // representations + // are treated the same. + value = std::numeric_limits::quiet_NaN(); } return filter_->testInt64(reinterpret_cast(value)); }); break; case TypeKind::DOUBLE: applyTyped(rows, input, context, result, [&](double value) { - auto* derived = - dynamic_cast*>(filter_.get()); - if (derived) { - return filter_->testDouble(value); - } if (value == double{}) { value = 0; + } else if (std::isnan(value)) { + // We normalize NaN values to ensure different binary + // representations + // are treated the same. 
+ value = std::numeric_limits::quiet_NaN(); } return filter_->testInt64(reinterpret_cast(value)); }); diff --git a/velox/functions/prestosql/tests/InPredicateTest.cpp b/velox/functions/prestosql/tests/InPredicateTest.cpp index 07ba60b402e7..d1733eaf3c36 100644 --- a/velox/functions/prestosql/tests/InPredicateTest.cpp +++ b/velox/functions/prestosql/tests/InPredicateTest.cpp @@ -258,6 +258,114 @@ class InPredicateTest : public FunctionBaseTest { return makeFlatVector(timestamps); } + + template + void testNaNs() { + const T kNaN = std::numeric_limits::quiet_NaN(); + const T kSNaN = std::numeric_limits::signaling_NaN(); + TypePtr columnFloatType = CppToType::create(); + + // Constant In-list, primitive input. + auto testInWithConstList = [&](std::vector input, + std::vector inlist, + std::vector expected) { + auto expr = std::make_shared( + BOOLEAN(), + std::vector{ + field(columnFloatType, "c0"), + std::make_shared( + makeArrayVector({inlist})), + }, + "in"); + auto data = makeRowVector({ + makeFlatVector(input), + }); + auto expectedResults = makeFlatVector(expected); + auto result = evaluate(expr, data); + assertEqualVectors(expectedResults, result); + }; + + testInWithConstList({kNaN, kSNaN}, {kNaN, 1}, {true, true}); + testInWithConstList({kNaN, kSNaN}, {1, 2}, {false, false}); + // Need to specifically test in-list with a single element as it previously + // had a separate codepath. + testInWithConstList({kNaN, kSNaN}, {kNaN}, {true, true}); + testInWithConstList({kNaN, kSNaN}, {1}, {false, false}); + + { + // Constant In-list, complex input(row). + // In-list is [row{kNaN, 1}]. + auto inlist = makeArrayVector( + {0}, + makeRowVector( + {makeFlatVector(std::vector({kNaN})), + makeFlatVector(std::vector({1}))})); + auto expr = std::make_shared( + BOOLEAN(), + std::vector{ + field(ROW({columnFloatType, INTEGER()}), "c0"), + std::make_shared(inlist), + }, + "in"); + // Input is [row{kNaN, 1}, row{kSNaN, 1}, row{kNaN, 2}]. 
+ auto data = makeRowVector({makeRowVector( + {makeFlatVector(std::vector({kNaN, kSNaN, kNaN})), + makeFlatVector(std::vector({1, 1, 2}))})}); + auto expectedResults = makeFlatVector({true, true, false}); + auto result = evaluate(expr, data); + assertEqualVectors(expectedResults, result); + } + + { + // Variable In-list, primitive input. + auto data = makeRowVector({ + makeFlatVector({kNaN, kSNaN, kNaN}), + makeFlatVector({kNaN, kNaN, 0}), + makeFlatVector({1, 1, 1}), + }); + // Expression: c0 in (c1, c2) + auto inWithVariableInList = std::make_shared( + BOOLEAN(), + std::vector{ + field(columnFloatType, "c0"), + field(columnFloatType, "c1"), + field(columnFloatType, "c2"), + }, + "in"); + auto expectedResults = makeFlatVector({ + true, // kNaN in (kNaN, 1) + true, // kSNaN in (kNaN, 1) + false, // kNaN in (0, 1) + }); + auto result = evaluate(inWithVariableInList, data); + assertEqualVectors(expectedResults, result); + } + + { + // Variable In-list, complex input(row). + // Input is: + // c0: [row{kNaN, 1}, row{kSNaN, 1}, row{kNaN, 2}] + // c1: [row{kNaN, 1}, row{kNaN, 1}, row{kNaN, 1}] + auto data = makeRowVector( + {makeRowVector( + {makeFlatVector(std::vector({kNaN, kSNaN, kNaN})), + makeFlatVector(std::vector({1, 1, 2}))}), + makeRowVector( + {makeFlatVector(std::vector({kNaN, kNaN, kNaN})), + makeFlatVector(std::vector({1, 1, 1}))})}); + // Expression: c0 in (c1) + auto inWithVariableInList = std::make_shared( + BOOLEAN(), + std::vector{ + field(ROW({columnFloatType, INTEGER()}), "c0"), + field(ROW({columnFloatType, INTEGER()}), "c1"), + }, + "in"); + auto expectedResults = makeFlatVector({true, true, false}); + auto result = evaluate(inWithVariableInList, data); + assertEqualVectors(expectedResults, result); + } + } }; TEST_F(InPredicateTest, bigint) { @@ -952,5 +1060,11 @@ TEST_F(InPredicateTest, nonConstantInList) { assertEqualVectors(expected, result); } +TEST_F(InPredicateTest, nans) { + // Ensure that NaNs with different bit patterns are treated as 
equal. + testNaNs(); + testNaNs(); +} + } // namespace } // namespace facebook::velox::functions