diff --git a/velox/functions/prestosql/Comparisons.cpp b/velox/functions/prestosql/Comparisons.cpp index 0f37bae010170..05eab2b80f569 100644 --- a/velox/functions/prestosql/Comparisons.cpp +++ b/velox/functions/prestosql/Comparisons.cpp @@ -21,6 +21,13 @@ namespace facebook::velox::functions { +using Eq = std::equal_to<>; +using Neq = std::not_equal_to<>; +using Lt = std::less<>; +using Lte = std::less_equal<>; +using Gt = std::greater<>; +using Gte = std::greater_equal<>; + namespace { /// This class implements comparison for vectors of primitive types using SIMD. @@ -94,6 +101,29 @@ struct SimdComparator { } } + template + inline bool compare(T& l, T& r) const { + if constexpr (std::is_floating_point_v) { + bool filtered = false; + if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareEquals{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = !util::floating_point::NaNAwareEquals{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareLessThan{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareLessThanEqual{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareGreaterThan{}(l, r); + } else if constexpr (std::is_same_v) { + filtered = util::floating_point::NaNAwareGreaterThanEqual{}(l, r); + } + return filtered; + } else { + return ComparisonOp()(l, r); + } + } + template < TypeKind kind, typename std::enable_if_t< @@ -113,24 +143,27 @@ struct SimdComparator { auto resultVector = result->asUnchecked>(); auto rawResult = resultVector->mutableRawValues(); - auto isSimdizable = (lhs.isConstantEncoding() || lhs.isFlatEncoding()) && + bool isSimdizable = (lhs.isConstantEncoding() || lhs.isFlatEncoding()) && (rhs.isConstantEncoding() || rhs.isFlatEncoding()) && rows.isAllSelected(); - if (!isSimdizable || std::is_same_v) { + static const bool isTypeNotSupported = + std::is_same_v || std::is_floating_point_v; + + if (!isSimdizable || isTypeNotSupported) { exec::LocalDecodedVector lhsDecoded(context, lhs, rows); exec::LocalDecodedVector rhsDecoded(context, rhs, rows); context.template applyToSelectedNoThrow(rows, [&](auto row) { auto l = lhsDecoded->template valueAt(row); auto r = rhsDecoded->template valueAt(row); - auto filtered = ComparisonOp()(l, r); + auto filtered = compare(l, r); resultVector->set(row, filtered); }); return; } - if constexpr (!std::is_same_v) { + if constexpr (!isTypeNotSupported) { if (lhs.isConstantEncoding() && rhs.isConstantEncoding()) { auto l = lhs.asUnchecked>()->valueAt(0); auto r = rhs.asUnchecked>()->valueAt(0); @@ -242,7 +275,7 @@ class ComparisonSimdFunction : public exec::VectorFunction { } exec::FunctionCanonicalName getCanonicalName() const override { - return std::is_same_v> + return std::is_same_v ? exec::FunctionCanonicalName::kLt : exec::FunctionCanonicalName::kUnknown; } @@ -252,32 +285,32 @@ class ComparisonSimdFunction : public exec::VectorFunction { VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_eq, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_neq, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_lt, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_gt, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_lte, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); VELOX_DECLARE_VECTOR_FUNCTION( udf_simd_comparison_gte, - (ComparisonSimdFunction>::signatures()), - (std::make_unique>>())); + (ComparisonSimdFunction::signatures()), + (std::make_unique>())); } // namespace facebook::velox::functions diff --git a/velox/functions/prestosql/Comparisons.h b/velox/functions/prestosql/Comparisons.h index c57b177c9f247..9550524063202 100644 --- a/velox/functions/prestosql/Comparisons.h +++ b/velox/functions/prestosql/Comparisons.h @@ -18,18 +18,23 @@ #include "velox/common/base/CompareFlags.h" #include "velox/functions/Macros.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" +#include "velox/type/FloatingPointUtil.h" namespace facebook::velox::functions { -#define VELOX_GEN_BINARY_EXPR(Name, Expr, TResult) \ - template \ - struct Name { \ - VELOX_DEFINE_FUNCTION_TYPES(T); \ - template \ - FOLLY_ALWAYS_INLINE void \ - call(TResult& result, const TInput& lhs, const TInput& rhs) { \ - result = (Expr); \ - } \ +#define VELOX_GEN_BINARY_EXPR(Name, Expr, ExprForFloats) \ + template \ + struct Name { \ + VELOX_DEFINE_FUNCTION_TYPES(T); \ + template \ + FOLLY_ALWAYS_INLINE void \ + call(bool& result, const TInput& lhs, const TInput& rhs) { \ + if constexpr (std::is_floating_point_v) { \ + result = (ExprForFloats); \ + return; \ + } \ + result = (Expr); \ + } \ }; #define VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE(Name, tsExpr, TResult) \ @@ -44,10 +49,22 @@ namespace facebook::velox::functions { } \ }; -VELOX_GEN_BINARY_EXPR(LtFunction, lhs < rhs, bool); -VELOX_GEN_BINARY_EXPR(GtFunction, lhs > rhs, bool); -VELOX_GEN_BINARY_EXPR(LteFunction, lhs <= rhs, bool); -VELOX_GEN_BINARY_EXPR(GteFunction, lhs >= rhs, bool); +VELOX_GEN_BINARY_EXPR( + LtFunction, + lhs < rhs, + util::floating_point::NaNAwareLessThan{}(lhs, rhs)); +VELOX_GEN_BINARY_EXPR( + GtFunction, + lhs > rhs, + util::floating_point::NaNAwareGreaterThan{}(lhs, rhs)); +VELOX_GEN_BINARY_EXPR( + LteFunction, + lhs <= rhs, + util::floating_point::NaNAwareLessThanEqual{}(lhs, rhs)); +VELOX_GEN_BINARY_EXPR( + GteFunction, + lhs >= rhs, + util::floating_point::NaNAwareGreaterThanEqual{}(lhs, rhs)); VELOX_GEN_BINARY_EXPR_TIMESTAMP_WITH_TIME_ZONE( LtFunction, @@ -99,6 +116,10 @@ struct EqFunction { // Used for primitive inputs. template void call(bool& out, const TInput& lhs, const TInput& rhs) { + if constexpr (std::is_floating_point_v) { + out = util::floating_point::NaNAwareEquals{}(lhs, rhs); + return; + } out = (lhs == rhs); } @@ -138,6 +159,10 @@ struct NeqFunction { // Used for primitive inputs. template void call(bool& out, const TInput& lhs, const TInput& rhs) { + if constexpr (std::is_floating_point_v) { + out = !util::floating_point::NaNAwareEquals{}(lhs, rhs); + return; + } out = (lhs != rhs); } @@ -172,6 +197,12 @@ struct BetweenFunction { template FOLLY_ALWAYS_INLINE void call(bool& result, const T& value, const T& low, const T& high) { + if constexpr (std::is_floating_point_v) { + result = + util::floating_point::NaNAwareGreaterThanEqual{}(value, low) && + util::floating_point::NaNAwareLessThanEqual{}(value, high); + return; + } result = value >= low && value <= high; } }; diff --git a/velox/functions/prestosql/tests/ComparisonsTest.cpp b/velox/functions/prestosql/tests/ComparisonsTest.cpp index b1228a02096b1..8dc53d3bc6645 100644 --- a/velox/functions/prestosql/tests/ComparisonsTest.cpp +++ b/velox/functions/prestosql/tests/ComparisonsTest.cpp @@ -16,6 +16,8 @@ #include #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/Udf.h" +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/prestosql/Comparisons.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" using namespace facebook::velox; @@ -34,6 +36,16 @@ class ComparisonsTest : public functions::test::FunctionBaseTest { auto actual = evaluate(exprStr, makeRowVector(input)); test::assertEqualVectors(expectedResult, actual); } + + void registerSimpleComparisonFunctions() { + using namespace facebook::velox::functions; + registerBinaryScalar({"simple_eq"}); + registerBinaryScalar({"simple_neq"}); + registerBinaryScalar({"simple_lt"}); + registerBinaryScalar({"simple_lte"}); + registerBinaryScalar({"simple_gt"}); + registerBinaryScalar({"simple_gte"}); + } }; TEST_F(ComparisonsTest, between) { @@ -641,6 +653,91 @@ TEST_F(ComparisonsTest, overflowTest) { } } +TEST_F(ComparisonsTest, nanComparison) { + registerSimpleComparisonFunctions(); + static const auto kNaN = std::numeric_limits::quiet_NaN(); + static const auto kSNaN = std::numeric_limits::signaling_NaN(); + static const auto kInf = std::numeric_limits::infinity(); + + auto testNaN = + [&](std::string prefix, RowVectorPtr rowVector, bool primitiveInput) { + auto eval = [&](const std::string& expr, + const std::string& lhs, + const std::string& rhs) { + return evaluate>( + fmt::format("{}({}, {})", expr, lhs, rhs), rowVector); + }; + + auto allFalse = makeFlatVector({false, false}); + auto allTrue = makeFlatVector({true, true}); + + // NaN compared with NaN (multiple binary representations) + test::assertEqualVectors(eval(prefix + "eq", "c0", "c1"), allTrue); + test::assertEqualVectors(eval(prefix + "neq", "c0", "c1"), allFalse); + if (primitiveInput) { + test::assertEqualVectors(eval(prefix + "lt", "c0", "c1"), allFalse); + test::assertEqualVectors(eval(prefix + "gt", "c0", "c1"), allFalse); + test::assertEqualVectors(eval(prefix + "lte", "c0", "c1"), allTrue); + test::assertEqualVectors(eval(prefix + "gte", "c0", "c1"), allTrue); + // NaN between Infinity and NaN + test::assertEqualVectors( + evaluate>("c0 BETWEEN c2 and c1", rowVector), + allTrue); + // NaN distinct from NaN + test::assertEqualVectors( + evaluate>("c0 IS DISTINCT FROM c1", rowVector), + allFalse); + } + + // NaN compared with Inf + test::assertEqualVectors(eval(prefix + "eq", "c0", "c2"), allFalse); + test::assertEqualVectors(eval(prefix + "neq", "c0", "c2"), allTrue); + if (primitiveInput) { + test::assertEqualVectors(eval(prefix + "lt", "c0", "c2"), allFalse); + test::assertEqualVectors(eval(prefix + "gt", "c0", "c2"), allTrue); + test::assertEqualVectors(eval(prefix + "lte", "c0", "c2"), allFalse); + test::assertEqualVectors(eval(prefix + "gte", "c0", "c2"), allTrue); + // NaN between 0 and Infinity + test::assertEqualVectors( + evaluate>( + "c0 BETWEEN cast(0 as double) and c2", rowVector), + allFalse); + // NaN distinct from Infinity + test::assertEqualVectors( + evaluate>("c0 IS DISTINCT FROM c2", rowVector), + allTrue); + } + }; + + // Primitive type input + auto input = makeRowVector( + {makeFlatVector({kNaN, kSNaN}), + makeFlatVector({kNaN, kNaN}), + makeFlatVector({kInf, kInf})}); + // Test the Vector function ComparisonSimdFunction. + testNaN("", input, true); + // Test the Simple functions. + testNaN("simple_", input, true); + + // Complex type input + input = makeRowVector({ + makeRowVector({ + makeFlatVector({kNaN, kSNaN}), + makeFlatVector({1, 1}), + }), + makeRowVector({ + makeFlatVector({kNaN, kNaN}), + makeFlatVector({1, 1}), + }), + makeRowVector({ + makeFlatVector({kInf, kInf}), + makeFlatVector({1, 1}), + }), + }); + // Note: Complex comparison functions are only registered as simple functions. + testNaN("", input, false); +} + namespace { template struct ComparisonTypeOp { @@ -766,29 +863,6 @@ class SimdComparisonsTest : public functions::test::FunctionBaseTest { rhsVector.begin(), rhsVector.end(), std::numeric_limits::min()); testVectorComparison(lhsVector, rhsVector); - - // Add tests against Nan and other edge cases. - if constexpr (std::is_floating_point_v) { - lhsVector = std::vector(47); - rhsVector = std::vector(47); - - std::fill( - lhsVector.begin(), - lhsVector.end(), - std::numeric_limits::signaling_NaN()); - std::fill( - rhsVector.begin(), - rhsVector.end(), - std::numeric_limits::signaling_NaN()); - testVectorComparison(lhsVector, rhsVector); - - std::fill( - lhsVector.begin(), - lhsVector.end(), - std::numeric_limits::signaling_NaN()); - std::fill(rhsVector.begin(), rhsVector.end(), 1); - testVectorComparison(lhsVector, rhsVector); - } } void testDictionary() { diff --git a/velox/type/FloatingPointUtil.h b/velox/type/FloatingPointUtil.h index 298537ebcd4fc..082d5d8c6c5ea 100644 --- a/velox/type/FloatingPointUtil.h +++ b/velox/type/FloatingPointUtil.h @@ -58,6 +58,18 @@ struct NaNAwareLessThan { } }; +template < + typename FLOAT, + std::enable_if_t::value, bool> = true> +struct NaNAwareLessThanEqual { + bool operator()(const FLOAT& lhs, const FLOAT& rhs) const { + if (std::isnan(rhs)) { + return true; + } + return lhs <= rhs; + } +}; + template < typename FLOAT, std::enable_if_t::value, bool> = true> @@ -70,6 +82,18 @@ struct NaNAwareGreaterThan { } }; +template < + typename FLOAT, + std::enable_if_t::value, bool> = true> +struct NaNAwareGreaterThanEqual { + bool operator()(const FLOAT& lhs, const FLOAT& rhs) const { + if (std::isnan(lhs)) { + return true; + } + return lhs >= rhs; + } +}; + template < typename FLOAT, std::enable_if_t::value, bool> = true> diff --git a/velox/type/tests/FloatingPointUtilTest.cpp b/velox/type/tests/FloatingPointUtilTest.cpp index e1139e96e1ea0..c603d18d0e5dd 100644 --- a/velox/type/tests/FloatingPointUtilTest.cpp +++ b/velox/type/tests/FloatingPointUtilTest.cpp @@ -39,11 +39,25 @@ void testFloatingPoint() { ASSERT_FALSE(NaNAwareLessThan{}(kNaN, kInf)); ASSERT_TRUE(NaNAwareLessThan{}(kInf, kNaN)); + ASSERT_TRUE(NaNAwareLessThanEqual{}(kNaN, kNaN)); + ASSERT_TRUE(NaNAwareLessThanEqual{}(kNaN, kSNAN)); + ASSERT_TRUE(NaNAwareLessThanEqual{}(kInf, kNaN)); + ASSERT_TRUE(NaNAwareLessThanEqual{}(0.0, kInf)); + ASSERT_FALSE(NaNAwareLessThanEqual{}(kNaN, kInf)); + ASSERT_FALSE(NaNAwareLessThanEqual{}(kNaN, 0.0)); + ASSERT_FALSE(NaNAwareGreaterThan{}(kNaN, kNaN)); ASSERT_FALSE(NaNAwareGreaterThan{}(kNaN, kSNAN)); ASSERT_FALSE(NaNAwareGreaterThan{}(kInf, kNaN)); ASSERT_TRUE(NaNAwareGreaterThan{}(kNaN, kInf)); + ASSERT_TRUE(NaNAwareGreaterThanEqual{}(kNaN, kNaN)); + ASSERT_TRUE(NaNAwareGreaterThanEqual{}(kNaN, kSNAN)); + ASSERT_FALSE(NaNAwareGreaterThanEqual{}(kInf, kNaN)); + ASSERT_FALSE(NaNAwareGreaterThanEqual{}(0.0, kInf)); + ASSERT_TRUE(NaNAwareGreaterThanEqual{}(kNaN, kInf)); + ASSERT_TRUE(NaNAwareGreaterThanEqual{}(kNaN, 0.0)); + ASSERT_EQ(NaNAwareHash{}(kNaN), NaNAwareHash{}(kSNAN)); ASSERT_EQ(NaNAwareHash{}(kNaN), NaNAwareHash{}(kNaN)); ASSERT_EQ(NaNAwareHash{}(0.0), NaNAwareHash{}(0.0));