diff --git a/velox/expression/tests/CastExprTest.cpp b/velox/expression/tests/CastExprTest.cpp index 6e83395e2bcc4..c8822f016d10e 100644 --- a/velox/expression/tests/CastExprTest.cpp +++ b/velox/expression/tests/CastExprTest.cpp @@ -2086,12 +2086,12 @@ TEST_F(CastExprTest, doubleToDecimal) { DOUBLE(), DECIMAL(38, 2), {std::numeric_limits::max()}, - "Cannot cast DOUBLE '1.7976931348623157E308' to DECIMAL(38, 2). Result overflows."); + "Result overflows."); testThrow( DOUBLE(), DECIMAL(38, 2), {std::numeric_limits::lowest()}, - "Cannot cast DOUBLE '-1.7976931348623157E308' to DECIMAL(38, 2). Result overflows."); + "Result overflows."); testCast( makeConstant(std::numeric_limits::min(), 1), makeConstant(0, 1, DECIMAL(38, 2))); @@ -2169,10 +2169,7 @@ TEST_F(CastExprTest, realToDecimal) { DECIMAL(38, 18))); testThrow( - REAL(), - DECIMAL(10, 2), - {9999999999999999999999.99}, - "Cannot cast REAL '9.999999778196308E21' to DECIMAL(10, 2). Result overflows."); + REAL(), DECIMAL(10, 2), {9999999999999999999999.99}, "Result overflows."); testThrow( REAL(), DECIMAL(10, 2), @@ -2189,22 +2186,22 @@ TEST_F(CastExprTest, realToDecimal) { REAL(), DECIMAL(20, 2), {static_cast(DecimalUtil::kLongDecimalMax)}, - "Cannot cast REAL '9.999999680285692E37' to DECIMAL(20, 2). Result overflows."); + "Result overflows."); testThrow( REAL(), DECIMAL(20, 2), {static_cast(DecimalUtil::kLongDecimalMin)}, - "Cannot cast REAL '-9.999999680285692E37' to DECIMAL(20, 2). Result overflows."); + "Result overflows."); testThrow( REAL(), DECIMAL(38, 2), {std::numeric_limits::max()}, - "Cannot cast REAL '3.4028234663852886E38' to DECIMAL(38, 2). Result overflows."); + "Result overflows."); testThrow( REAL(), DECIMAL(38, 2), {std::numeric_limits::lowest()}, - "Cannot cast REAL '-3.4028234663852886E38' to DECIMAL(38, 2). Result overflows."); + "Result overflows."); testCast( makeConstant(std::numeric_limits::min(), 1), makeConstant(0, 1, DECIMAL(38, 2))); diff --git a/velox/type/DecimalUtil.h b/velox/type/DecimalUtil.h index 6c46032975bbf..e55a241652e19 100644 --- a/velox/type/DecimalUtil.h +++ b/velox/type/DecimalUtil.h @@ -225,41 +225,55 @@ class DecimalUtil { } else { // A double provides from 15 to 17 decimal digits, so at least 15 digits // are precise. - digits = 15; - if (value <= std::numeric_limits::min() || - value >= std::numeric_limits::max()) { - return Status::UserError("Result overflows."); - } + digits = 9; } - // Calculate the precise fractional digits. - const auto integralValue = - static_cast(value > 0 ? value : -value); - const auto integralDigits = - integralValue == 0 ? 0 : countDigits(integralValue); - const auto fractionDigits = digits - integralDigits; - /// Scales up the input value to keep all the precise fractional digits - /// before rounding. Convert value to long double type, as double * int128_t - /// returns int128_t and fractional digits are lost. No need to consider - /// 'toValue' becoming infinite as DOUBLE_MAX * 10^38 < LONG_DOUBLE_MAX. - const auto scaledValue = - (long double)value * DecimalUtil::kPowersOfTen[fractionDigits]; - - long double rounded; - if (scale > fractionDigits) { - rounded = std::round(scaledValue) * - DecimalUtil::kPowersOfTen[scale - fractionDigits]; + // long double rounded; + // if (scale > digits) { + // // Convert value to long double type, as double * int128_t returns + // // int128_t and fractional digits are lost. No need to consider 'toValue' + // // becoming infinite as DOUBLE_MAX * 10^38 < LONG_DOUBLE_MAX. + // const auto toValue = + // (long double)value * DecimalUtil::kPowersOfTen[digits]; + // rounded = std::round(toValue) * DecimalUtil::kPowersOfTen[scale - digits]; + // } else { + // const auto toValue = + // (long double)value * DecimalUtil::kPowersOfTen[scale]; + // rounded = std::round(toValue); + // } + long double toValue; + if (value > 0) { + toValue = std::nextafter(value, value + 0.1); } else { - rounded = std::round( - std::round(scaledValue) / - DecimalUtil::kPowersOfTen[fractionDigits - scale]); + toValue = std::nextafter(value, value - 0.1); } - - const auto result = folly::tryTo(rounded); - if (result.hasError()) { - return Status::UserError("Result overflows."); + TOutput rescaledValue; + if (scale > digits) { + toValue = toValue * DecimalUtil::kPowersOfTen[digits]; + // double sign = (double(0) < toValue) - (toValue < double(0)); + // toValue += 1e-9 * sign; + auto result = folly::tryTo(std::round(toValue)); + if (result.hasError()) { + return Status::UserError("Result overflows."); + } + rescaledValue = result.value() * DecimalUtil::kPowersOfTen[scale - digits]; + } else { + toValue = toValue * DecimalUtil::kPowersOfTen[scale]; + // Add the sign (-1, 0, 1) times a tiny value to fix floating point issues. + double sign = (double(0) < toValue) - (toValue < double(0)); + toValue += 1e-9 * sign; + result = folly::tryTo(std::round(toValue)); + if (result.hasError()) { + return Status::UserError("Result overflows."); + } + rescaledValue = result.value(); } - const TOutput rescaledValue = result.value(); + + // const auto result = folly::tryTo(std::round(toValue)); + // if (result.hasError()) { + // return Status::UserError("Result overflows."); + // } + // const TOutput rescaledValue = result.value(); if (!valueInPrecisionRange(rescaledValue, precision)) { return Status::UserError( "Result cannot fit in the given precision {}.", precision);