diff --git a/velox/docs/functions/presto/math.rst b/velox/docs/functions/presto/math.rst index e8bb8e708dc5..96ea786eb588 100644 --- a/velox/docs/functions/presto/math.rst +++ b/velox/docs/functions/presto/math.rst @@ -159,14 +159,17 @@ Mathematical Functions Returns the base-``radix`` representation of ``x``. ``radix`` must be between 2 and 36. -.. function:: truncate(x) -> double +.. function:: truncate(x) -> [same as x] Returns x rounded to integer by dropping digits after decimal point. + Supported types of ``x`` are: REAL and DOUBLE. -.. function:: truncate(x, n) -> double +.. function:: truncate(x, n) -> [same as x] :noindex: Returns x truncated to n decimal places. n can be negative to truncate n digits left of the decimal point. + Supported types of ``x`` are: REAL and DOUBLE. + ``n`` is an INTEGER. .. function:: width_bucket(x, bound1, bound2, n) -> bigint diff --git a/velox/functions/prestosql/Arithmetic.h b/velox/functions/prestosql/Arithmetic.h index 85b5e86b4cd0..3757284bf2f4 100644 --- a/velox/functions/prestosql/Arithmetic.h +++ b/velox/functions/prestosql/Arithmetic.h @@ -580,15 +580,23 @@ struct EulerConstantFunction { } }; -template +template struct TruncateFunction { FOLLY_ALWAYS_INLINE void call(double& result, double a) { result = std::trunc(a); } + FOLLY_ALWAYS_INLINE void call(float& result, float a) { + result = std::trunc(a); + } + FOLLY_ALWAYS_INLINE void call(double& result, double a, int32_t n) { result = truncate(a, n); } + + FOLLY_ALWAYS_INLINE void call(float& result, float a, int32_t n) { + result = truncate(a, n); + } }; template diff --git a/velox/functions/prestosql/ArithmeticImpl.h b/velox/functions/prestosql/ArithmeticImpl.h index 6ea7d532aae7..d126ba5bd3fe 100644 --- a/velox/functions/prestosql/ArithmeticImpl.h +++ b/velox/functions/prestosql/ArithmeticImpl.h @@ -133,16 +133,14 @@ T ceil(const T& arg) { return results; } -FOLLY_ALWAYS_INLINE double truncate( - const double& number, - const int32_t& decimals = 0) { +FOLLY_ALWAYS_INLINE double truncate(double number, int32_t decimals) { const bool decNegative = (decimals < 0); const auto log10Size = DoubleUtil::kPowersOfTen.size(); // 309 if (decNegative && decimals <= -log10Size) { return 0.0; } - const uint64_t absDec = decNegative ? -decimals : decimals; + const uint64_t absDec = std::abs(decimals); const double tmp = (absDec < log10Size) ? DoubleUtil::kPowersOfTen[absDec] : std::pow(10.0, (double)absDec); diff --git a/velox/functions/prestosql/registration/MathematicalFunctionsRegistration.cpp b/velox/functions/prestosql/registration/MathematicalFunctionsRegistration.cpp index 450a06d87223..0cd45a1c5b93 100644 --- a/velox/functions/prestosql/registration/MathematicalFunctionsRegistration.cpp +++ b/velox/functions/prestosql/registration/MathematicalFunctionsRegistration.cpp @@ -22,6 +22,14 @@ namespace facebook::velox::functions { namespace { + +void registerTruncate(const std::vector& names) { + registerFunction(names); + registerFunction(names); + registerFunction(names); + registerFunction(names); +} + void registerMathFunctions(const std::string& prefix) { registerUnaryNumeric({prefix + "ceil", prefix + "ceiling"}); registerUnaryNumeric({prefix + "floor"}); @@ -98,9 +106,9 @@ void registerMathFunctions(const std::string& prefix) { {prefix + "to_base"}); registerFunction({prefix + "pi"}); registerFunction({prefix + "e"}); - registerFunction({prefix + "truncate"}); - registerFunction( - {prefix + "truncate"}); + + registerTruncate({prefix + "truncate"}); + registerFunction< CosineSimilarityFunction, double, diff --git a/velox/functions/prestosql/tests/ArithmeticTest.cpp b/velox/functions/prestosql/tests/ArithmeticTest.cpp index 272c377d3193..953f8ef7e3b9 100644 --- a/velox/functions/prestosql/tests/ArithmeticTest.cpp +++ b/velox/functions/prestosql/tests/ArithmeticTest.cpp @@ -776,12 +776,33 @@ TEST_F(ArithmeticTest, clamp) { EXPECT_EQ(clamp(123456, 1, -1), -1); } -TEST_F(ArithmeticTest, truncate) { - const auto truncate = [&](std::optional a, - std::optional n = 0) { - return evaluateOnce("truncate(c0,c1)", a, n); +TEST_F(ArithmeticTest, truncateDouble) { + const auto truncate = [&](std::optional d) { + const auto r = evaluateOnce("truncate(c0)", d); + + // truncate(d) == truncate(d, 0) + if (d.has_value() && std::isfinite(d.value())) { + const auto otherResult = + evaluateOnce("truncate(c0, 0::integer)", d); + + VELOX_CHECK_EQ(r.value(), otherResult.value()); + } + + return r; }; + const auto truncateN = [&](std::optional d, + std::optional n) { + return evaluateOnce("truncate(c0, c1)", d, n); + }; + + EXPECT_EQ(truncate(0), 0); + EXPECT_EQ(truncate(1.5), 1); + EXPECT_EQ(truncate(-1.5), -1); + EXPECT_EQ(truncate(std::nullopt), std::nullopt); + EXPECT_THAT(truncate(kNan), IsNan()); + EXPECT_THAT(truncate(kInf), IsInf()); + EXPECT_EQ(truncate(0), 0); EXPECT_EQ(truncate(1.5), 1); EXPECT_EQ(truncate(-1.5), -1); @@ -789,37 +810,79 @@ TEST_F(ArithmeticTest, truncate) { EXPECT_THAT(truncate(kNan), IsNan()); EXPECT_THAT(truncate(kInf), IsInf()); - EXPECT_EQ(truncate(0, 0), 0); - EXPECT_EQ(truncate(1.5, 0), 1); - EXPECT_EQ(truncate(-1.5, 0), -1); - EXPECT_EQ(truncate(std::nullopt, 0), std::nullopt); - EXPECT_EQ(truncate(1.5, std::nullopt), std::nullopt); - EXPECT_THAT(truncate(kNan, 0), IsNan()); - EXPECT_THAT(truncate(kNan, 1), IsNan()); - EXPECT_THAT(truncate(kInf, 0), IsInf()); - EXPECT_THAT(truncate(kInf, 1), IsInf()); - - EXPECT_DOUBLE_EQ(truncate(1.5678, 2).value(), 1.56); - EXPECT_DOUBLE_EQ(truncate(-1.5678, 2).value(), -1.56); - EXPECT_DOUBLE_EQ(truncate(1.333, -1).value(), 0); - EXPECT_DOUBLE_EQ(truncate(3.54555, 2).value(), 3.54); - EXPECT_DOUBLE_EQ(truncate(1234, 1).value(), 1234); - EXPECT_DOUBLE_EQ(truncate(1234, -1).value(), 1230); - EXPECT_DOUBLE_EQ(truncate(1234.56, 1).value(), 1234.5); - EXPECT_DOUBLE_EQ(truncate(1234.56, -1).value(), 1230.0); - EXPECT_DOUBLE_EQ(truncate(1239.999, 2).value(), 1239.99); - EXPECT_DOUBLE_EQ(truncate(1239.999, -2).value(), 1200.0); + EXPECT_EQ(truncateN(1.5, std::nullopt), std::nullopt); + EXPECT_THAT(truncateN(kNan, 1), IsNan()); + EXPECT_THAT(truncateN(kInf, 1), IsInf()); + + EXPECT_DOUBLE_EQ(truncateN(1.5678, 2).value(), 1.56); + EXPECT_DOUBLE_EQ(truncateN(-1.5678, 2).value(), -1.56); + EXPECT_DOUBLE_EQ(truncateN(1.333, -1).value(), 0); + EXPECT_DOUBLE_EQ(truncateN(3.54555, 2).value(), 3.54); + EXPECT_DOUBLE_EQ(truncateN(1234, 1).value(), 1234); + EXPECT_DOUBLE_EQ(truncateN(1234, -1).value(), 1230); + EXPECT_DOUBLE_EQ(truncateN(1234.56, 1).value(), 1234.5); + EXPECT_DOUBLE_EQ(truncateN(1234.56, -1).value(), 1230.0); + EXPECT_DOUBLE_EQ(truncateN(1239.999, 2).value(), 1239.99); + EXPECT_DOUBLE_EQ(truncateN(1239.999, -2).value(), 1200.0); EXPECT_DOUBLE_EQ( - truncate(123456789012345678901.23, 3).value(), 123456789012345678901.23); + truncateN(123456789012345678901.23, 3).value(), 123456789012345678901.23); EXPECT_DOUBLE_EQ( - truncate(-123456789012345678901.23, 3).value(), + truncateN(-123456789012345678901.23, 3).value(), -123456789012345678901.23); EXPECT_DOUBLE_EQ( - truncate(123456789123456.999, 2).value(), 123456789123456.99); - EXPECT_DOUBLE_EQ(truncate(123456789012345678901.0, -21).value(), 0.0); - EXPECT_DOUBLE_EQ(truncate(123456789012345678901.23, -21).value(), 0.0); - EXPECT_DOUBLE_EQ(truncate(123456789012345678901.0, -21).value(), 0.0); - EXPECT_DOUBLE_EQ(truncate(123456789012345678901.23, -21).value(), 0.0); + truncateN(123456789123456.999, 2).value(), 123456789123456.99); + EXPECT_DOUBLE_EQ(truncateN(123456789012345678901.0, -21).value(), 0.0); + EXPECT_DOUBLE_EQ(truncateN(123456789012345678901.23, -21).value(), 0.0); + EXPECT_DOUBLE_EQ(truncateN(123456789012345678901.0, -21).value(), 0.0); + EXPECT_DOUBLE_EQ(truncateN(123456789012345678901.23, -21).value(), 0.0); +} + +TEST_F(ArithmeticTest, truncateReal) { + const auto truncate = [&](std::optional d) { + const auto r = evaluateOnce("truncate(c0)", d); + + // truncate(d) == truncate(d, 0) + if (d.has_value() && std::isfinite(d.value())) { + const auto otherResult = + evaluateOnce("truncate(c0, 0::integer)", d); + + VELOX_CHECK_EQ(r.value(), otherResult.value()); + } + + return r; + }; + + const auto truncateN = [&](std::optional d, std::optional n) { + return evaluateOnce("truncate(c0, c1)", d, n); + }; + + EXPECT_EQ(truncate(0), 0); + EXPECT_EQ(truncate(1.5), 1); + EXPECT_EQ(truncate(-1.5), -1); + + EXPECT_EQ(truncate(std::nullopt), std::nullopt); + EXPECT_THAT(truncate(kNan), IsNan()); + EXPECT_THAT(truncate(kInf), IsInf()); + + EXPECT_FLOAT_EQ(truncateN(123.456, 0).value(), 123); + EXPECT_FLOAT_EQ(truncateN(123.456, 1).value(), 123.4); + EXPECT_FLOAT_EQ(truncateN(123.456, 2).value(), 123.45); + EXPECT_FLOAT_EQ(truncateN(123.456, 3).value(), 123.456); + EXPECT_FLOAT_EQ(truncateN(123.456, 4).value(), 123.456); + + EXPECT_FLOAT_EQ(truncateN(123.456, -1).value(), 120); + EXPECT_FLOAT_EQ(truncateN(123.456, -2).value(), 100); + EXPECT_FLOAT_EQ(truncateN(123.456, -3).value(), 0); + + EXPECT_FLOAT_EQ(truncateN(-123.456, 0).value(), -123); + EXPECT_FLOAT_EQ(truncateN(-123.456, 1).value(), -123.4); + EXPECT_FLOAT_EQ(truncateN(-123.456, 2).value(), -123.45); + EXPECT_FLOAT_EQ(truncateN(-123.456, 3).value(), -123.456); + EXPECT_FLOAT_EQ(truncateN(-123.456, 4).value(), -123.456); + + EXPECT_FLOAT_EQ(truncateN(-123.456, -1).value(), -120); + EXPECT_FLOAT_EQ(truncateN(-123.456, -2).value(), -100); + EXPECT_FLOAT_EQ(truncateN(-123.456, -3).value(), 0); } TEST_F(ArithmeticTest, wilsonIntervalLower) {