From 3884939853312782e5252084f28cd3f806d705fe Mon Sep 17 00:00:00 2001 From: Chengcheng Jin Date: Tue, 12 Nov 2024 10:06:06 -0800 Subject: [PATCH] Add support for allow-precision-loss in decimal operations (#10383) Summary: Each of the decimal operation functions is registered as two functions such as `add_deny_precision_loss` and `add`. When allowing precision loss, establishing the result type of an arithmetic operation happens according to Hive behavior and SQL ANSI 2011 specification, i.e. rounding the decimal part of the result if an exact representation is not possible. Otherwise, NULL is returned in those cases, as previously. When not allowing precision loss, not rounding the decimal part. For example, | decimal(38, 7) + decimal(10, 0) result type | 1.1232154 + 1| decimal(38, 18) * decimal(38, 18)| 0.1234567891011 * 1234.1 -- | -- | -- | -- | -- allow precision loss | decimal(38, 6) | 2.123215 | decimal(38, 6) | 152.358023 deny precision loss | decimal(38, 7) | 2.1232154 | decimal(38, 36) | NULL ``` spark-sql (default)> set spark.sql.decimalOperations.allowPrecisionLoss=true; spark-sql (default)> select cast(0.1234567891011 as decimal(38, 18)) * cast(1234.1 as decimal(38, 18)); 152.358023 spark-sql (default)> set spark.sql.decimalOperations.allowPrecisionLoss=false; spark-sql (default)> select cast(0.1234567891011 as decimal(38, 18)) * cast(1234.1 as decimal(38, 18)); NULL ``` Spark implementation: https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala#L814 Pull Request resolved: https://github.com/facebookincubator/velox/pull/10383 Reviewed By: pedroerp Differential Revision: D65612198 Pulled By: kevinwilfong fbshipit-source-id: 4910aaeb0e375dbe8817c5f3fb41185c67c6dd5b --- velox/docs/functions/spark/decimal.rst | 90 ++++++- .../functions/sparksql/DecimalArithmetic.cpp | 229 +++++++++++++----- velox/functions/sparksql/DecimalUtil.h | 10 + 
.../sparksql/tests/DecimalArithmeticTest.cpp | 59 +++++ .../sparksql/tests/DecimalUtilTest.cpp | 12 + 5 files changed, 332 insertions(+), 68 deletions(-) diff --git a/velox/docs/functions/spark/decimal.rst b/velox/docs/functions/spark/decimal.rst index 19eee325f4b3..e6901166c627 100644 --- a/velox/docs/functions/spark/decimal.rst +++ b/velox/docs/functions/spark/decimal.rst @@ -2,15 +2,46 @@ Decimal Operators ================= -When calculating the result precision and scale of arithmetic operators, -the formulas follow Hive which is based on the SQL standard and MS SQL: +The result precision and scale computation of arithmetic operators consists of two stages. +The first stage computes precision and scale using formulas based on the SQL standard and Hive when allow-precision-loss is true. +The result may exceed the maximum allowed precision of 38. + +The second stage caps precision at 38 and either reduces the scale or not depending on the allow-precision-loss flag. + +For example, addition of decimal(38, 7) and decimal(10, 0) requires precision of 39 and scale of 7. +Since precision exceeds 38 it needs to be capped. When allow-precision-loss is true, precision is capped at 38 and scale is reduced by 1 to 6. +When allow-precision-loss is false, precision is capped at 38 as well, but scale is kept at 7. +With allow-precision-loss all additions will succeed, but accuracy (number of digits after period) of some operations will be reduced. +Without allow-precision-loss, some additions will return NULL. 
+ +For example, + +The following queries keep accuracy or return NULL when allow-precision-loss is false: + +:: + + select cast('1.1232154' as decimal(38, 7)) + cast('1' as decimal(10, 0)); -- 2.1232154 + select cast('9999999999999999999999999999999.2345678' as decimal(38, 7)) + cast('1' as decimal(10, 0)); -- NULL + +These same operations succeed when allow-precision-loss is true: + +:: + + select cast('1.1232154' as decimal(38, 7)) + cast('1' as decimal(10, 0)); -- 2.123215, lost the last digit + select cast('9999999999999999999999999999999.2345678' as decimal(38, 7)) + cast('1' as decimal(10, 0)); -- 10000000000000000000000000000000.234568 + +Decimal Precision and Scale Computation Formulas +------------------------------------------------ + +The HiveQL behavior: https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf -https://msdn.microsoft.com/en-us/library/ms190476.aspx +Additionally, the computation of decimal division adapts to the allow-precision-loss flag, +while the decimal addition, subtraction, and multiplication do not. Addition and Subtraction ------------------------- +~~~~~~~~~~~~~~~~~~~~~~~~ :: @@ -18,7 +49,7 @@ Addition and Subtraction s = max(s1, s2) Multiplication --------------- +~~~~~~~~~~~~~~ :: @@ -26,25 +57,60 @@ Multiplication s = s1 + s2 Division --------- +~~~~~~~~ +When allow-precision-loss is true: :: p = p1 - s1 + s2 + max(6, s1 + p2 + 1) s = max(6, s1 + p2 + 1) -For above arithmetic operators, when the precision of result exceeds 38, -caps p at 38 and reduces the scale, in order to prevent the truncation of -the integer part of the decimals. Below formula illustrates how the result -precision and scale are adjusted. 
+When allow-precision-loss is false: + +:: + + wholeDigits = min(38, p1 - s1 + s2); + fractionalDigits = min(38, max(6, s1 + p2 + 1)); + p = wholeDigits + fractionalDigits + s = fractionalDigits + +Decimal Precision and Scale Adjustment +-------------------------------------- + +When allow-precision-loss is true, rounds the decimal part of the result if an exact representation is not possible. +Otherwise, returns NULL. +Notice: some operations succeed if precision loss is allowed and return NULL if not. + +For example, + +:: + + select cast(0.1234567891011 as decimal(38, 18)) * cast(1234.1 as decimal(38, 18)); + -- 152.358023 if allow-precision-loss, NULL otherwise. + +Below formula illustrates how the result precision and scale are adjusted. :: precision = 38 scale = max(38 - (p - s), min(s, 6)) -Users experience runtime errors when the actual result cannot be represented -with the calculated decimal type. +When precision loss is not allowed, caps p at 38, and keeps scale as is. +The below formula shows how the precision and scale are adjusted for decimal addition, subtraction, and multiplication. + +:: + + precision = 38 + scale = min(38, s) + +Decimal division uses a different formula: + +:: + + precision = 38 + scale = fractionalDigits - (wholeDigits + fractionalDigits - 38) / 2 - 1 + +Returns NULL when the actual result cannot be represented with the calculated decimal type. 
Decimal Functions ----------------- diff --git a/velox/functions/sparksql/DecimalArithmetic.cpp b/velox/functions/sparksql/DecimalArithmetic.cpp index 61599bce10ea..7f5b48ea7740 100644 --- a/velox/functions/sparksql/DecimalArithmetic.cpp +++ b/velox/functions/sparksql/DecimalArithmetic.cpp @@ -21,15 +21,18 @@ namespace facebook::velox::functions::sparksql { namespace { +static constexpr const char* kDenyPrecisionLoss = "_deny_precision_loss"; + struct DecimalAddSubtractBase { protected: + template void initializeBase(const std::vector& inputTypes) { auto [aPrecision, aScale] = getDecimalPrecisionScale(*inputTypes[0]); auto [bPrecision, bScale] = getDecimalPrecisionScale(*inputTypes[1]); aScale_ = aScale; bScale_ = bScale; - auto [rPrecision, rScale] = - computeResultPrecisionScale(aPrecision, aScale_, bPrecision, bScale_); + auto [rPrecision, rScale] = computeResultPrecisionScale( + aPrecision, aScale_, bPrecision, bScale_); rPrecision_ = rPrecision; rScale_ = rScale; aRescale_ = computeRescaleFactor(aScale_, bScale_); @@ -252,11 +255,13 @@ struct DecimalAddSubtractBase { } } - // Computes the result precision and scale for decimal add and subtract - // operations following Hive's formulas. - // If result is representable with long decimal, the result - // scale is the maximum of 'aScale' and 'bScale'. If not, reduces result scale - // and caps the result precision at 38. + // When `allowPrecisionLoss` is true, computes the result precision and scale + // for decimal add and subtract operations following Hive's formulas. If + // result is representable with long decimal, the result scale is the maximum + // of 'aScale' and 'bScale'. If not, reduces result scale and caps the result + // precision at 38. + // When `allowPrecisionLoss` is false, caps p and s at 38. 
+ template static std::pair computeResultPrecisionScale( uint8_t aPrecision, uint8_t aScale, @@ -265,7 +270,11 @@ struct DecimalAddSubtractBase { auto precision = std::max(aPrecision - aScale, bPrecision - bScale) + std::max(aScale, bScale) + 1; auto scale = std::max(aScale, bScale); - return sparksql::DecimalUtil::adjustPrecisionScale(precision, scale); + if constexpr (allowPrecisionLoss) { + return sparksql::DecimalUtil::adjustPrecisionScale(precision, scale); + } else { + return sparksql::DecimalUtil::bounded(precision, scale); + } } static uint8_t computeRescaleFactor(uint8_t fromScale, uint8_t toScale) { @@ -280,7 +289,7 @@ struct DecimalAddSubtractBase { uint8_t rScale_; }; -template +template struct DecimalAddFunction : DecimalAddSubtractBase { VELOX_DEFINE_FUNCTION_TYPES(TExec); @@ -290,7 +299,7 @@ struct DecimalAddFunction : DecimalAddSubtractBase { const core::QueryConfig& /*config*/, A* /*a*/, B* /*b*/) { - initializeBase(inputTypes); + initializeBase(inputTypes); } template @@ -299,7 +308,7 @@ struct DecimalAddFunction : DecimalAddSubtractBase { } }; -template +template struct DecimalSubtractFunction : DecimalAddSubtractBase { VELOX_DEFINE_FUNCTION_TYPES(TExec); @@ -309,7 +318,7 @@ struct DecimalSubtractFunction : DecimalAddSubtractBase { const core::QueryConfig& /*config*/, A* /*a*/, B* /*b*/) { - initializeBase(inputTypes); + initializeBase(inputTypes); } template @@ -318,7 +327,7 @@ struct DecimalSubtractFunction : DecimalAddSubtractBase { } }; -template +template struct DecimalMultiplyFunction { VELOX_DEFINE_FUNCTION_TYPES(TExec); @@ -330,10 +339,16 @@ struct DecimalMultiplyFunction { B* /*b*/) { auto [aPrecision, aScale] = getDecimalPrecisionScale(*inputTypes[0]); auto [bPrecision, bScale] = getDecimalPrecisionScale(*inputTypes[1]); - auto [rPrecision, rScale] = DecimalUtil::adjustPrecisionScale( - aPrecision + bPrecision + 1, aScale + bScale); - rPrecision_ = rPrecision; - deltaScale_ = aScale + bScale - rScale; + std::pair rPrecisionScale; + if 
constexpr (allowPrecisionLoss) { + rPrecisionScale = DecimalUtil::adjustPrecisionScale( + aPrecision + bPrecision + 1, aScale + bScale); + } else { + rPrecisionScale = + DecimalUtil::bounded(aPrecision + bPrecision + 1, aScale + bScale); + } + rPrecision_ = rPrecisionScale.first; + deltaScale_ = aScale + bScale - rPrecisionScale.second; } template @@ -426,7 +441,7 @@ struct DecimalMultiplyFunction { int32_t deltaScale_; }; -template +template struct DecimalDivideFunction { VELOX_DEFINE_FUNCTION_TYPES(TExec); @@ -453,14 +468,30 @@ struct DecimalDivideFunction { } private: + // When allowing precision loss, computes the result precision and scale + // following Hive's formulas. When denying precision loss, calculates the + // number of whole digits and fraction digits. If the total number of digits + // exceed 38, we reduce both the number of fraction digits and whole digits to + // fit within this limit. static std::pair computeResultPrecisionScale( uint8_t aPrecision, uint8_t aScale, uint8_t bPrecision, uint8_t bScale) { - auto scale = std::max(6, aScale + bPrecision + 1); - auto precision = aPrecision - aScale + bScale + scale; - return DecimalUtil::adjustPrecisionScale(precision, scale); + if constexpr (allowPrecisionLoss) { + auto scale = std::max(6, aScale + bPrecision + 1); + auto precision = aPrecision - aScale + bScale + scale; + return DecimalUtil::adjustPrecisionScale(precision, scale); + } else { + auto wholeDigits = std::min(38, aPrecision - aScale + bScale); + auto fractionDigits = std::min(38, std::max(6, aScale + bPrecision + 1)); + auto diff = (wholeDigits + fractionDigits) - 38; + if (diff > 0) { + fractionDigits -= diff / 2 + 1; + wholeDigits = 38 - fractionDigits; + } + return DecimalUtil::bounded(wholeDigits + fractionDigits, fractionDigits); + } } uint8_t aRescale_; @@ -507,16 +538,24 @@ void registerDecimalBinary( ShortDecimal>({name}, constraints); } +// Used in function registration to generate the string to cap value at 38. 
+std::string bounded(const std::string& value) { + return fmt::format("({}) <= 38 ? ({}) : 38", value, value); +} + std::vector makeConstraints( const std::string& rPrecision, - const std::string& rScale) { - std::string finalScale = fmt::format( - "({}) <= 38 ? ({}) : max(({}) - ({}) + 38, min(({}), 6))", - rPrecision, - rScale, - rScale, - rPrecision, - rScale); + const std::string& rScale, + bool allowPrecisionLoss) { + std::string finalScale = allowPrecisionLoss + ? fmt::format( + "({}) <= 38 ? ({}) : max(({}) - ({}) + 38, min(({}), 6))", + rPrecision, + rScale, + rScale, + rPrecision, + rScale) + : bounded(rScale); return { exec::SignatureVariable( P3::name(), @@ -527,8 +566,7 @@ std::vector makeConstraints( S3::name(), finalScale, exec::ParameterType::kIntegerParameter)}; } -template