From f32385af2660be6a705121c65ce61dae31fba8e7 Mon Sep 17 00:00:00 2001 From: Bikramjeet Vig Date: Tue, 2 Jul 2024 17:38:10 -0700 Subject: [PATCH] Fix NaN handling in comparison functions (#10165) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/10165 This change ensures that comparison functions treat all NaN bit representations as equal and greater than infinity. Summary of changes: Currently, numerical primitive types, including floating points, have an optimized SIMD implementation. However, this relies on low-level instructions that follow IEEE floating point semantics, where NaN compared to any other number, including itself, returns false. There is no trivial way to implement NaN checking and handling using SIMD, so for now, we are falling back to the regular non-SIMD version that employs comparators supporting NaN handling in the intended way. Additionally, this change currently handles this by adding a special case in the ComparisonSimdFunction vector function instead of registering the SimpleFunction for floating types. This is because registering an additional/different function to handle floating types is causing a regression in an internal integration which is still under investigation. Reviewed By: kgpai Differential Revision: D58471982 --- velox/functions/prestosql/Comparisons.cpp | 67 +++++++--- velox/functions/prestosql/Comparisons.h | 57 +++++++-- .../prestosql/tests/ComparisonsTest.cpp | 120 ++++++++++++++---- velox/type/FloatingPointUtil.h | 24 ++++ velox/type/tests/FloatingPointUtilTest.cpp | 14 ++ 5 files changed, 229 insertions(+), 53 deletions(-) diff --git a/velox/functions/prestosql/Comparisons.cpp b/velox/functions/prestosql/Comparisons.cpp index 0f37bae010170..05eab2b80f569 100644 --- a/velox/functions/prestosql/Comparisons.cpp +++ b/velox/functions/prestosql/Comparisons.cpp @@ -21,6 +21,13 @@ namespace facebook::velox::functions { +using Eq = std::equal_to<>; +using Neq = std::not_equal_to<>; +using Lt = std::less<>; +using Lte = std::less_equal<>; +using Gt = std::greater<>; +using Gte = std::greater_equal<>; + namespace { /// This class implements comparison for vectors of primitive types using SIMD. @@ -94,6 +101,29 @@ struct SimdComparator { } } + template + inline bool compare(T& l, T& r) const { + if constexpr (std::is_floating_point_v) { + bool filtered = false; + if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareEquals{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = !util::floating_point::NaNAwareEquals{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareLessThan{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareLessThanEqual{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareGreaterThan{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareGreaterThanEqual{}(l, r); + } + return filtered; + } else { + return ComparisonOp()(l, r); + } + } + template < TypeKind kind, typename std::enable_if_t< @@ -113,24 +143,27 @@ struct SimdComparator { auto resultVector = result->asUnchecked>(); auto rawResult = resultVector->mutableRawValues(); - auto isSimdizable = (lhs.isConstantEncoding() || lhs.isFlatEncoding()) && + bool isSimdizable = (lhs.isConstantEncoding() || lhs.isFlatEncoding()) && (rhs.isConstantEncoding() || rhs.isFlatEncoding()) && rows.isAllSelected(); - if (!isSimdizable || std::is_same_v) { + static const bool isTypeNotSupported = + std::is_same_v || std::is_floating_point_v; + + if (!isSimdizable || isTypeNotSupported) { exec::LocalDecodedVector lhsDecoded(context, lhs, rows); exec::LocalDecodedVector rhsDecoded(context, rhs, rows); context.template applyToSelectedNoThrow(rows, [&](auto row) { auto l = lhsDecoded->template valueAt(row); auto r = rhsDecoded->template valueAt(row); - auto filtered = ComparisonOp()(l, r); + auto filtered = compare(l, r); resultVector->set(row, filtered); }); return; } - if constexpr (!std::is_same_v) { + if constexpr (!isTypeNotSupported) { if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) { auto l = lhs.asUnchecked>()->valueAt(0); auto r = rhs.asUnchecked>()->valueAt(0); @@ -242,7 +275,7 @@ class ComparisonSimdFunction : public exec::VectorFunction { } exec::FunctionCanonicalName getCanonicalName() const override { - return std::is_same_v> + return std::is_same_v ? exec::FunctionCanonicalName::kLt : exec::FunctionCanonicalName::kUnknown; } @@ -252,32 +285,32 @@ class ComparisonSimdFunction : public exec::VectorFunction { VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_eq, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_neq, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_lt, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_gt, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_lte, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_gte, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/Comparisons.h b/velox/functions/prestosql/Comparisons.h index c57b177c9f247..9550524063202 100644 --- a/velox/functions/prestosql/Comparisons.h +++ b/velox/functions/prestosql/Comparisons.h @@ -18,18 +18,23 @@ #include "velox/common/base/CompareFlags.h" #include "velox/functions/Macros.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/FloatingPointUtil.h" namespace facebook::velox::functions { -#define VELOX_GEN_BINARY_EXPR(Name, Expr, TResult) \ - template \ - struct Name { \ - VELOX_DEFINE_FUNCTION_TYPES(T); \ - template \ - FOLLY_ALWAYS_INLINE void \ - call(TResult& result, const TInput& lhs, const TInput& rhs) { \ - result = (Expr); \ - } \ +#define VELOX_GEN_BINARY_EXPR(Name, Expr, ExprForFloats) \ + template \ + struct Name { \ + VELOX_DEFINE_FUNCTION_TYPES(T); \ + template \ + FOLLY_ALWAYS_INLINE void \ + call(bool& result, const TInput& lhs, const TInput& rhs) { \ + if constexpr (std::is_floating_point_v) { \ + result = (ExprForFloats); \ + return; \ + } \ + result = (Expr); \ + } \ }; #define VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(Name, tsExpr, TResult) \ @@ -44,10 +49,22 @@ namespace facebook::velox::functions { } \ }; -VELOX_GEN_BINARY_EXPR(LtFunction, lhs < rhs, bool); -VELOX_GEN_BINARY_EXPR(GtFunction, lhs > rhs, bool); -VELOX_GEN_BINARY_EXPR(LteFunction, lhs <= rhs, bool); -VELOX_GEN_BINARY_EXPR(GteFunction, lhs >= rhs, bool); +VELOX_GEN_BINARY_EXPR( + LtFunction, + lhs < rhs, + util::floating_point::NaNAwareLessThan{}(lhs, rhs)); +VELOX_GEN_BINARY_EXPR( + GtFunction, + lhs > rhs, + util::floating_point::NaNAwareGreaterThan{}(lhs, rhs)); +VELOX_GEN_BINARY_EXPR( + LteFunction, + lhs <= rhs, + util::floating_point::NaNAwareLessThanEqual{}(lhs, rhs)); +VELOX_GEN_BINARY_EXPR( + GteFunction, + lhs >= rhs, + util::floating_point::NaNAwareGreaterThanEqual{}(lhs, rhs)); VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE( LtFunction, @@ -99,6 +116,10 @@ struct EqFunction { // Used for primitive inputs. template void call(bool& out, const TInput& lhs, const TInput& rhs) { + if constexpr (std::is_floating_point_v) { + out = util::floating_point::NaNAwareEquals{}(lhs, rhs); + return; + } out = (lhs == rhs); } @@ -138,6 +159,10 @@ struct NeqFunction { // Used for primitive inputs. template void call(bool& out, const TInput& lhs, const TInput& rhs) { + if constexpr (std::is_floating_point_v) { + out = !util::floating_point::NaNAwareEquals{}(lhs, rhs); + return; + } out = (lhs != rhs); } @@ -172,6 +197,12 @@ struct BetweenFunction { template FOLLY_ALWAYS_INLINE void call(bool& result, const T& value, const T& low, const T& high) { + if constexpr (std::is_floating_point_v) { + result = + util::floating_point::NaNAwareGreaterThanEqual{}(value, low) && + util::floating_point::NaNAwareLessThanEqual{}(value, high); + return; + } result = value >= low && value <= high; } }; diff --git a/velox/functions/prestosql/tests/ComparisonsTest.cpp b/velox/functions/prestosql/tests/ComparisonsTest.cpp index b1228a02096b1..8dc53d3bc6645 100644 --- a/velox/functions/prestosql/tests/ComparisonsTest.cpp +++ b/velox/functions/prestosql/tests/ComparisonsTest.cpp @@ -16,6 +16,8 @@ #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/Udf.h" +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/prestosql/Comparisons.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" using namespace facebook::velox; @@ -34,6 +36,16 @@ class ComparisonsTest : public functions::test::FunctionBaseTest { auto actual = evaluate(exprStr, makeRowVector(input)); test::assertEqualVectors(expectedResult, actual); } + + void registerSimpleComparisonFunctions() { + using namespace facebook::velox::functions; + registerBinaryScalar({"simple_eq"}); + registerBinaryScalar({"simple_neq"}); + registerBinaryScalar({"simple_lt"}); + registerBinaryScalar({"simple_lte"}); + registerBinaryScalar({"simple_gt"}); + registerBinaryScalar({"simple_gte"}); + } }; TEST_F(ComparisonsTest, between) { @@ -641,6 +653,91 @@ TEST_F(ComparisonsTest, overflowTest) { } } +TEST_F(ComparisonsTest, nanComparison) { + registerSimpleComparisonFunctions(); + static const auto kNaN = std::numeric_limits::quiet_NaN(); + static const auto kSNaN = std::numeric_limits::signaling_NaN(); + static const auto kInf = std::numeric_limits::infinity(); + + auto testNaN = + [&](std::string prefix, RowVectorPtr rowVector, bool primitiveInput) { + auto eval = [&](const std::string& expr, + const std::string& lhs, + const std::string& rhs) { + return evaluate>( + fmt::format("{}({}, {})", expr, lhs, rhs), rowVector); + }; + + auto allFalse = makeFlatVector({false, false}); + auto allTrue = makeFlatVector({true, true}); + + // NaN compared with NaN (multiple binary representations) + test::assertEqualVectors(eval(prefix + "eq", "c0", "c1"), allTrue); + test::assertEqualVectors(eval(prefix + "neq", "c0", "c1"), allFalse); + if (primitiveInput) { + test::assertEqualVectors(eval(prefix + "lt", "c0", "c1"), allFalse); + test::assertEqualVectors(eval(prefix + "gt", "c0", "c1"), allFalse); + test::assertEqualVectors(eval(prefix + "lte", "c0", "c1"), allTrue); + test::assertEqualVectors(eval(prefix + "gte", "c0", "c1"), allTrue); + // NaN between Infinity and NaN + test::assertEqualVectors( + evaluate>("c0 BETWEEN c2 and c1", rowVector), + allTrue); + // NaN distinct from NaN + test::assertEqualVectors( + evaluate>("c0 IS DISTINCT FROM c1", rowVector), + allFalse); + } + + // NaN compared with Inf + test::assertEqualVectors(eval(prefix + "eq", "c0", "c2"), allFalse); + test::assertEqualVectors(eval(prefix + "neq", "c0", "c2"), allTrue); + if (primitiveInput) { + test::assertEqualVectors(eval(prefix + "lt", "c0", "c2"), allFalse); + test::assertEqualVectors(eval(prefix + "gt", "c0", "c2"), allTrue); + test::assertEqualVectors(eval(prefix + "lte", "c0", "c2"), allFalse); + test::assertEqualVectors(eval(prefix + "gte", "c0", "c2"), allTrue); + // NaN between 0 and Infinity + test::assertEqualVectors( + evaluate>( + "c0 BETWEEN cast(0 as double) and c2", rowVector), + allFalse); + // NaN distinct from Infinity + test::assertEqualVectors( + evaluate>("c0 IS DISTINCT FROM c2", rowVector), + allTrue); + } + }; + + // Primitive type input + auto input = makeRowVector( + {makeFlatVector({kNaN, kSNaN}), + makeFlatVector({kNaN, kNaN}), + makeFlatVector({kInf, kInf})}); + // Test the Vector function ComparisonSimdFunction. + testNaN("", input, true); + // Test the Simple functions. + testNaN("simple_", input, true); + + // Complex type input + input = makeRowVector({ + makeRowVector({ + makeFlatVector({kNaN, kSNaN}), + makeFlatVector({1, 1}), + }), + makeRowVector({ + makeFlatVector({kNaN, kNaN}), + makeFlatVector({1, 1}), + }), + makeRowVector({ + makeFlatVector({kInf, kInf}), + makeFlatVector({1, 1}), + }), + }); + // Note: Complex comparison functions are only registered as simple functions. + testNaN("", input, false); +} + namespace { template struct ComparisonTypeOp { @@ -766,29 +863,6 @@ class SimdComparisonsTest : public functions::test::FunctionBaseTest { rhsVector.begin(), rhsVector.end(), std::numeric_limits::min()); testVectorComparison(lhsVector, rhsVector); - - // Add tests against Nan and other edge cases. - if constexpr (std::is_floating_point_v) { - lhsVector = std::vector(47); - rhsVector = std::vector(47); - - std::fill( - lhsVector.begin(), - lhsVector.end(), - std::numeric_limits::signaling_NaN()); - std::fill( - rhsVector.begin(), - rhsVector.end(), - std::numeric_limits::signaling_NaN()); - testVectorComparison(lhsVector, rhsVector); - - std::fill( - lhsVector.begin(), - lhsVector.end(), - std::numeric_limits::signaling_NaN()); - std::fill(rhsVector.begin(), rhsVector.end(), 1); - testVectorComparison(lhsVector, rhsVector); - } } void testDictionary() { diff --git a/velox/type/FloatingPointUtil.h b/velox/type/FloatingPointUtil.h index 298537ebcd4fc..082d5d8c6c5ea 100644 --- a/velox/type/FloatingPointUtil.h +++ b/velox/type/FloatingPointUtil.h @@ -58,6 +58,18 @@ struct NaNAwareLessThan { } }; +template < + typename FLOAT, + std::enable_if_t::value, bool> = true> +struct NaNAwareLessThanEqual { + bool operator()(const FLOAT& lhs, const FLOAT& rhs) const { + if (std::isnan(rhs)) { + return true; + } + return lhs <= rhs; + } +}; + template < typename FLOAT, std::enable_if_t::value, bool> = true> @@ -70,6 +82,18 @@ struct NaNAwareGreaterThan { } }; +template < + typename FLOAT, + std::enable_if_t::value, bool> = true> +struct NaNAwareGreaterThanEqual { + bool operator()(const FLOAT& lhs, const FLOAT& rhs) const { + if (std::isnan(lhs)) { + return true; + } + return lhs >= rhs; + } +}; + template < typename FLOAT, std::enable_if_t::value, bool> = true> diff --git a/velox/type/tests/FloatingPointUtilTest.cpp b/velox/type/tests/FloatingPointUtilTest.cpp index e1139e96e1ea0..c603d18d0e5dd 100644 --- a/velox/type/tests/FloatingPointUtilTest.cpp +++ b/velox/type/tests/FloatingPointUtilTest.cpp @@ -39,11 +39,25 @@ void testFloatingPoint() { ASSERT_FALSE(NaNAwareLessThan{}(kNaN, kInf)); ASSERT_TRUE(NaNAwareLessThan{}(kInf, kNaN)); + ASSERT_TRUE(NaNAwareLessThanEqual{}(kNaN, kNaN)); + ASSERT_TRUE(NaNAwareLessThanEqual{}(kNaN, kSNAN)); + ASSERT_TRUE(NaNAwareLessThanEqual{}(kInf, kNaN)); + ASSERT_TRUE(NaNAwareLessThanEqual{}(0.0, kInf)); + ASSERT_FALSE(NaNAwareLessThanEqual{}(kNaN, kInf)); + ASSERT_FALSE(NaNAwareLessThanEqual{}(kNaN, 0.0)); + ASSERT_FALSE(NaNAwareGreaterThan{}(kNaN, kNaN)); ASSERT_FALSE(NaNAwareGreaterThan{}(kNaN, kSNAN)); ASSERT_FALSE(NaNAwareGreaterThan{}(kInf, kNaN)); ASSERT_TRUE(NaNAwareGreaterThan{}(kNaN, kInf)); + ASSERT_TRUE(NaNAwareGreaterThanEqual{}(kNaN, kNaN)); + ASSERT_TRUE(NaNAwareGreaterThanEqual{}(kNaN, kSNAN)); + ASSERT_FALSE(NaNAwareGreaterThanEqual{}(kInf, kNaN)); + ASSERT_FALSE(NaNAwareGreaterThanEqual{}(0.0, kInf)); + ASSERT_TRUE(NaNAwareGreaterThanEqual{}(kNaN, kInf)); + ASSERT_TRUE(NaNAwareGreaterThanEqual{}(kNaN, 0.0)); + ASSERT_EQ(NaNAwareHash{}(kNaN), NaNAwareHash{}(kSNAN)); ASSERT_EQ(NaNAwareHash{}(kNaN), NaNAwareHash{}(kNaN)); ASSERT_EQ(NaNAwareHash{}(0.0), NaNAwareHash{}(0.0));