Skip to content

Commit

Permalink
Fix the signature of the decimal multiply
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Apr 2, 2024
1 parent f64795f commit 67b8804
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
23 changes: 22 additions & 1 deletion velox/functions/prestosql/DecimalFunctions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,25 @@ template <typename TExec>
struct DecimalMultiplyFunction {
VELOX_DEFINE_FUNCTION_TYPES(TExec);

template <typename A, typename B>
void initialize(
const std::vector<TypePtr>& inputTypes,
const core::QueryConfig& /*config*/,
A* /*a*/,
B* /*b*/) {
const auto aType = inputTypes[0];
const auto bType = inputTypes[1];
const auto [aPrecision, aScale] = getDecimalPrecisionScale(*aType);
const auto [bPrecision, bScale] = getDecimalPrecisionScale(*bType);
const auto rPrecision = std::min(38, aPrecision + bPrecision);
const auto rScale = aScale + bScale;
VELOX_USER_CHECK_LE(
rScale,
rPrecision,
"DECIMAL scale must be in range [0, {}].",
rPrecision);
}

template <typename R, typename A, typename B>
void call(R& out, const A& a, const B& b) {
out = checkedMultiply<R>(checkedMultiply<R>(R(a), R(b)), R(1));
Expand Down Expand Up @@ -364,7 +383,9 @@ void registerDecimalMultiply(const std::string& prefix) {
exec::SignatureVariable(
S3::name(),
fmt::format(
"{a_scale} + {b_scale}",
"min({a_scale} + {b_scale}, min(38, {a_precision} + {b_precision}))",
fmt::arg("a_precision", P1::name()),
fmt::arg("b_precision", P2::name()),
fmt::arg("a_scale", S1::name()),
fmt::arg("b_scale", S2::name())),
exec::ParameterType::kIntegerParameter),
Expand Down
8 changes: 8 additions & 0 deletions velox/functions/prestosql/tests/DecimalArithmeticTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ TEST_F(DecimalArithmeticTest, multiply) {
HugeInt::build(0x08FFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF)},
DECIMAL(38, 0))}),
"Decimal overflow. Value '119630519620642428561342635425231011830' is not in the range of Decimal Type");

// The sum of input scales exceeds result precision.
VELOX_ASSERT_THROW(
evaluate(
"c0 * c0",
makeRowVector(
{makeFlatVector<int128_t>({1000, 2000}, DECIMAL(38, 30))})),
"DECIMAL scale must be in range [0, 38].");
}

TEST_F(DecimalArithmeticTest, decimalDivTest) {
Expand Down

0 comments on commit 67b8804

Please sign in to comment.