From 34c093e582b8c861497b039e09b91827f0945425 Mon Sep 17 00:00:00 2001 From: Pramod Satya Date: Sun, 17 Nov 2024 10:04:12 +0530 Subject: [PATCH] feat: Add mathematical operators for IntervalYearMonth type --- velox/functions/prestosql/Arithmetic.h | 57 +++-- .../MathematicalOperatorsRegistration.cpp | 35 +++ .../prestosql/tests/ArithmeticTest.cpp | 204 +++++++++--------- 3 files changed, 185 insertions(+), 111 deletions(-) diff --git a/velox/functions/prestosql/Arithmetic.h b/velox/functions/prestosql/Arithmetic.h index ff2eafc24dcb7..e86fc5b9a9b06 100644 --- a/velox/functions/prestosql/Arithmetic.h +++ b/velox/functions/prestosql/Arithmetic.h @@ -38,6 +38,8 @@ inline constexpr int kMinRadix = 2; inline constexpr int kMaxRadix = 36; inline constexpr long kLongMax = std::numeric_limits::max(); inline constexpr long kLongMin = std::numeric_limits::min(); +inline constexpr long kIntegerMax = std::numeric_limits::max(); +inline constexpr long kIntegerMin = std::numeric_limits::min(); inline constexpr char digits[36] = { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', @@ -73,7 +75,8 @@ struct MultiplyFunction { } }; -// Multiply function for IntervalDayTime * Double and Double * IntervalDayTime. +// Multiply function for IntervalDayTime * Double, Double * IntervalDayTime, +// IntervalYearMonth * Double and Double * IntervalYearMonth. template struct IntervalMultiplyFunction { FOLLY_ALWAYS_INLINE double sanitizeInput(double d) { @@ -84,12 +87,15 @@ struct IntervalMultiplyFunction { } template < + typename TResult, typename T1, typename T2, typename = std::enable_if_t< (std::is_same_v && std::is_same_v) || - (std::is_same_v && std::is_same_v)>> - FOLLY_ALWAYS_INLINE void call(int64_t& result, T1 a, T2 b) { + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v)>> + FOLLY_ALWAYS_INLINE void call(TResult& result, T1 a, T2 b) { double resultDouble; if constexpr (std::is_same_v) { resultDouble = sanitizeInput(a) * b; @@ -97,12 +103,23 @@ struct IntervalMultiplyFunction { resultDouble = sanitizeInput(b) * a; } + TResult min, max, maxCheck; + if constexpr (std::is_same_v) { + min = kLongMin; + max = kLongMax; + maxCheck = kMaxDoubleBelowInt64Max; + } else { + min = kIntegerMin; + max = kIntegerMax; + maxCheck = kIntegerMax; + } + if LIKELY ( - std::isfinite(resultDouble) && resultDouble >= kLongMin && - resultDouble <= kMaxDoubleBelowInt64Max) { - result = int64_t(resultDouble); + std::isfinite(resultDouble) && resultDouble >= min && + resultDouble <= maxCheck) { + result = TResult(resultDouble); } else { - result = resultDouble > 0 ? kLongMax : kLongMin; + result = resultDouble > 0 ? max : min; } } }; @@ -123,9 +140,14 @@ struct DivideFunction { } }; +// Divide function for IntervalDayTime / Double and IntervalYearMonth / Double. template struct IntervalDivideFunction { - FOLLY_ALWAYS_INLINE void call(int64_t& result, int64_t a, double b) + template < + typename TResult, + typename = std::enable_if_t< + std::is_same_v || std::is_same_v>> + FOLLY_ALWAYS_INLINE void call(TResult& result, TResult a, double b) // Depend on compiler have correct behaviour for divide by zero #if defined(__has_feature) #if __has_feature(__address_sanitizer__) @@ -134,17 +156,28 @@ struct IntervalDivideFunction { #endif #endif { + TResult min, max, maxCheck; + if constexpr (std::is_same_v) { + min = kLongMin; + max = kLongMax; + maxCheck = kMaxDoubleBelowInt64Max; + } else { + min = kIntegerMin; + max = kIntegerMax; + maxCheck = kIntegerMax; + } + if UNLIKELY (a == 0 || std::isnan(b) || !std::isfinite(b)) { result = 0; return; } double resultDouble = a / b; if LIKELY ( - std::isfinite(resultDouble) && resultDouble >= kLongMin && - resultDouble <= kMaxDoubleBelowInt64Max) { - result = int64_t(resultDouble); + std::isfinite(resultDouble) && resultDouble >= min && + resultDouble <= maxCheck) { + result = TResult(resultDouble); } else { - result = resultDouble > 0 ? kLongMax : kLongMin; + result = resultDouble > 0 ? max : min; } } }; diff --git a/velox/functions/prestosql/registration/MathematicalOperatorsRegistration.cpp b/velox/functions/prestosql/registration/MathematicalOperatorsRegistration.cpp index f5d806589bd35..bab24264a6e0e 100644 --- a/velox/functions/prestosql/registration/MathematicalOperatorsRegistration.cpp +++ b/velox/functions/prestosql/registration/MathematicalOperatorsRegistration.cpp @@ -28,12 +28,22 @@ void registerMathOperators(const std::string& prefix = "") { IntervalDayTime, IntervalDayTime, IntervalDayTime>({prefix + "plus"}); + registerFunction< + PlusFunction, + IntervalYearMonth, + IntervalYearMonth, + IntervalYearMonth>({prefix + "plus"}); registerBinaryFloatingPoint({prefix + "minus"}); registerFunction< MinusFunction, IntervalDayTime, IntervalDayTime, IntervalDayTime>({prefix + "minus"}); + registerFunction< + MinusFunction, + IntervalYearMonth, + IntervalYearMonth, + IntervalYearMonth>({prefix + "minus"}); registerBinaryFloatingPoint({prefix + "multiply"}); registerFunction( {prefix + "multiply"}); @@ -49,12 +59,37 @@ void registerMathOperators(const std::string& prefix = "") { IntervalDayTime, double, IntervalDayTime>({prefix + "multiply"}); + registerFunction< + MultiplyFunction, + IntervalYearMonth, + IntervalYearMonth, + int32_t>({prefix + "multiply"}); + registerFunction< + MultiplyFunction, + IntervalYearMonth, + int32_t, + IntervalYearMonth>({prefix + "multiply"}); + registerFunction< + IntervalMultiplyFunction, + IntervalYearMonth, + IntervalYearMonth, + double>({prefix + "multiply"}); + registerFunction< + IntervalMultiplyFunction, + IntervalYearMonth, + double, + IntervalYearMonth>({prefix + "multiply"}); registerBinaryFloatingPoint({prefix + "divide"}); registerFunction< IntervalDivideFunction, IntervalDayTime, IntervalDayTime, double>({prefix + "divide"}); + registerFunction< + IntervalDivideFunction, + IntervalYearMonth, + IntervalYearMonth, + double>({prefix + "divide"}); registerBinaryFloatingPoint({prefix + "mod"}); } diff --git a/velox/functions/prestosql/tests/ArithmeticTest.cpp b/velox/functions/prestosql/tests/ArithmeticTest.cpp index 08e1d718dba2a..855e90d7a181f 100644 --- a/velox/functions/prestosql/tests/ArithmeticTest.cpp +++ b/velox/functions/prestosql/tests/ArithmeticTest.cpp @@ -31,8 +31,6 @@ constexpr double kInf = std::numeric_limits::infinity(); constexpr double kNan = std::numeric_limits::quiet_NaN(); constexpr float kInfF = std::numeric_limits::infinity(); constexpr float kNanF = std::numeric_limits::quiet_NaN(); -constexpr int64_t kLongMax = std::numeric_limits::max(); -constexpr int64_t kLongMin = std::numeric_limits::min(); MATCHER(IsNan, "is NaN") { return arg && std::isnan(*arg); @@ -90,30 +88,109 @@ class ArithmeticTest : public functions::test::FunctionBaseTest { std::string(e.what()).find(errorMessage) != std::string::npos); } } + + template + void testIntervalPlus(const TypePtr& intervalType) { + T max = std::numeric_limits::max(); + T min = std::numeric_limits::min(); + auto op1 = makeNullableFlatVector( + {-1, 2, -3, 4, max, -1, std::nullopt, 0}, intervalType); + auto op2 = makeNullableFlatVector( + {2, -3, -1, 1, 1, min, 0, std::nullopt}, intervalType); + auto expected = makeNullableFlatVector( + {1, -1, -4, 5, min, max, std::nullopt, std::nullopt}, intervalType); + assertExpression("c0 + c1", op1, op2, expected); + } + + template + void testIntervalMinus(const TypePtr& intervalType) { + T max = std::numeric_limits::max(); + T min = std::numeric_limits::min(); + auto op1 = makeNullableFlatVector( + {-1, 2, -3, 4, min, -1, std::nullopt, 0}, intervalType); + auto op2 = makeNullableFlatVector( + {2, 3, -4, 1, 1, max, 0, std::nullopt}, intervalType); + auto expected = makeNullableFlatVector( + {-3, -1, 1, 3, max, min, std::nullopt, std::nullopt}, intervalType); + assertExpression("c0 - c1", op1, op2, expected); + } + + template + void testIntervalDivide(const TypePtr& intervalType) { + auto op1 = makeNullableFlatVector( + {3, 6, 9, std::nullopt, 12, 15, 18, 21, 0, 0, 1}, intervalType); + auto op2 = makeNullableFlatVector( + {0.5, + -2.0, + 5.0, + 1.0, + std::nullopt, + kNan, + kInf, + -kInf, + 0.0, + -0.0, + 0.00000001}); + auto expected = makeNullableFlatVector( + {6, -3, 1, std::nullopt, std::nullopt, 0, 0, 0, 0, 0, 100000000}, + intervalType); + assertExpression("c0 / c1", op1, op2, expected); + + T max = std::numeric_limits::max(); + T min = std::numeric_limits::min(); + op1 = makeFlatVector({1, 1, 1, 1, max, min, max, min}, intervalType); + op2 = makeFlatVector( + {0.0, -0.0, 4.9e-324, -4.9e-324, 0.1, 0.1, -0.1, -0.1}); + expected = makeFlatVector( + {max, min, max, min, max, min, min, max}, intervalType); + assertExpression("c0 / c1", op1, op2, expected); + } + + template + void testIntervalMultiply(const TypePtr& intervalType) { + T max = std::numeric_limits::max(); + T min = std::numeric_limits::min(); + auto op1 = makeNullableFlatVector( + {1, 2, 3, std::nullopt, 10, 20}, intervalType); + auto op2 = makeNullableFlatVector( + {1, std::nullopt, 3, 4, max, min}, CppToType::create()); + auto expected = makeNullableFlatVector( + {1, std::nullopt, 9, std::nullopt, -10, 0}, intervalType); + assertExpression("c0 * c1", op1, op2, expected); + assertExpression("c1 * c0", op1, op2, expected); + + op1 = makeNullableFlatVector( + {1, 2, 3, std::nullopt, 10, 20, 30, 40, 1000, 1000, 1000, 1000}, + intervalType); + auto doubleOp = makeNullableFlatVector( + {-1.8, + 2.1, + kNan, + 4.2, + 0.0, + -0.0, + kInf, + -kInf, + 9223372036854775807.01, + -9223372036854775808.01, + 1.7e308, + -1.7e308}); + expected = makeNullableFlatVector( + {-1, 4, 0, std::nullopt, 0, 0, max, min, max, min, max, min}, + intervalType); + assertExpression("c0 * c1", op1, doubleOp, expected); + assertExpression("c1 * c0", op1, doubleOp, expected); + } }; TEST_F(ArithmeticTest, plus) { - // Test plus for intervals. - auto op1 = makeNullableFlatVector( - {-1, 2, -3, 4, kLongMax, -1, std::nullopt, 0}, INTERVAL_DAY_TIME()); - auto op2 = makeNullableFlatVector( - {2, -3, -1, 1, 1, kLongMin, 0, std::nullopt}, INTERVAL_DAY_TIME()); - auto expected = makeNullableFlatVector( - {1, -1, -4, 5, kLongMin, kLongMax, std::nullopt, std::nullopt}, - INTERVAL_DAY_TIME()); - assertExpression("c0 + c1", op1, op2, expected); + testIntervalPlus(INTERVAL_DAY_TIME()); + testIntervalPlus(INTERVAL_YEAR_MONTH()); } TEST_F(ArithmeticTest, minus) { - // Test plus for intervals. - auto op1 = makeNullableFlatVector( - {-1, 2, -3, 4, kLongMin, -1, std::nullopt, 0}, INTERVAL_DAY_TIME()); - auto op2 = makeNullableFlatVector( - {2, 3, -4, 1, 1, kLongMax, 0, std::nullopt}, INTERVAL_DAY_TIME()); - auto expected = makeNullableFlatVector( - {-3, -1, 1, 3, kLongMax, kLongMin, std::nullopt, std::nullopt}, - INTERVAL_DAY_TIME()); - assertExpression("c0 - c1", op1, op2, expected); + testIntervalMinus(INTERVAL_DAY_TIME()); + testIntervalMinus(INTERVAL_YEAR_MONTH()); } TEST_F(ArithmeticTest, divide) @@ -144,42 +221,9 @@ __attribute__((__no_sanitize__("float-divide-by-zero"))) assertExpression( "c0 / c1", {10.5, 9.2, 0.0, 0.0}, {2, 0, 0, -1}, {5.25, kInf, kNan, 0.0}); - // Test interval divided by double. - auto intervalVector = makeNullableFlatVector( - {3, 6, 9, std::nullopt, 12, 15, 18, 21, 0, 0, 1}, INTERVAL_DAY_TIME()); - auto doubleVector = makeNullableFlatVector( - {0.5, - -2.0, - 5.0, - 1.0, - std::nullopt, - kNan, - kInf, - -kInf, - 0.0, - -0.0, - 0.00000001}); - auto expected = makeNullableFlatVector( - {6, -3, 1, std::nullopt, std::nullopt, 0, 0, 0, 0, 0, 100000000}, - INTERVAL_DAY_TIME()); - assertExpression("c0 / c1", intervalVector, doubleVector, expected); - - intervalVector = makeFlatVector( - {1, 1, 1, 1, kLongMax, kLongMin, kLongMax, kLongMin}, - INTERVAL_DAY_TIME()); - doubleVector = makeFlatVector( - {0.0, -0.0, 4.9e-324, -4.9e-324, 0.1, 0.1, -0.1, -0.1}); - expected = makeFlatVector( - {kLongMax, - kLongMin, - kLongMax, - kLongMin, - kLongMax, - kLongMin, - kLongMin, - kLongMax}, - INTERVAL_DAY_TIME()); - assertExpression("c0 / c1", intervalVector, doubleVector, expected); + // Test interval day time and interval year month types divided by double. + testIntervalDivide(INTERVAL_DAY_TIME()); + testIntervalDivide(INTERVAL_YEAR_MONTH()); } TEST_F(ArithmeticTest, multiply) { @@ -189,49 +233,11 @@ TEST_F(ArithmeticTest, multiply) { {-1}, "integer overflow: -2147483648 * -1"); - // Test multiplication of interval type with bigint. - auto intervalVector = makeNullableFlatVector( - {1, 2, 3, std::nullopt, 10, 20}, INTERVAL_DAY_TIME()); - auto bigintVector = makeNullableFlatVector( - {1, std::nullopt, 3, 4, kLongMax, kLongMin}); - auto expected = makeNullableFlatVector( - {1, std::nullopt, 9, std::nullopt, -10, 0}, INTERVAL_DAY_TIME()); - assertExpression("c0 * c1", intervalVector, bigintVector, expected); - assertExpression("c1 * c0", intervalVector, bigintVector, expected); - - // Test multiplication of interval type with double. - intervalVector = makeNullableFlatVector( - {1, 2, 3, std::nullopt, 10, 20, 30, 40, 1000, 1000, 1000, 1000}, - INTERVAL_DAY_TIME()); - auto doubleVector = makeNullableFlatVector( - {-1.8, - 2.1, - kNan, - 4.2, - 0.0, - -0.0, - kInf, - -kInf, - 9223372036854775807.01, - -9223372036854775808.01, - 1.7e308, - -1.7e308}); - expected = makeNullableFlatVector( - {-1, - 4, - 0, - std::nullopt, - 0, - 0, - kLongMax, - kLongMin, - kLongMax, - kLongMin, - kLongMax, - kLongMin}, - INTERVAL_DAY_TIME()); - assertExpression("c0 * c1", intervalVector, doubleVector, expected); - assertExpression("c1 * c0", intervalVector, doubleVector, expected); + // Test multiplication of interval day time type with bigint and double. + testIntervalMultiply(INTERVAL_DAY_TIME()); + + // Test multiplication of interval year month type with integer and double. + testIntervalMultiply(INTERVAL_YEAR_MONTH()); } TEST_F(ArithmeticTest, mod) {