Skip to content

Commit

Permalink
Add unscaled_value Spark function (#6013)
Browse files Browse the repository at this point in the history
Summary:
Fixes #6012

Pull Request resolved: #6013

Reviewed By: pedroerp

Differential Revision: D51072575

Pulled By: mbasmanova

fbshipit-source-id: ef617f92da9c740bc2c4f060c80faf9ef1f92eeb
  • Loading branch information
liujiayi771 authored and facebook-github-bot committed Nov 8, 2023
1 parent f73648a commit 703f2f2
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 1 deletion.
8 changes: 8 additions & 0 deletions velox/docs/functions/spark/decimal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
=================
Decimal functions
=================

.. spark:function:: unscaled_value(x) -> bigint
Return the unscaled bigint value of a short decimal ``x``.
Supported type is: SHORT_DECIMAL.
1 change: 1 addition & 0 deletions velox/docs/spark_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Spark Functions

functions/spark/math
functions/spark/bitwise
functions/spark/decimal
functions/spark/comparison
functions/spark/string
functions/spark/datetime
Expand Down
3 changes: 2 additions & 1 deletion velox/functions/sparksql/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ add_library(
RegisterCompare.cpp
Size.cpp
SplitFunctions.cpp
String.cpp)
String.cpp
UnscaledValueFunction.cpp)

# GCC 12 has a bug where it does not respect "pragma ignore" directives and ends
# up failing compilation in an openssl header included by a hash-related
Expand Down
7 changes: 7 additions & 0 deletions velox/functions/sparksql/Register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "velox/functions/sparksql/RegisterCompare.h"
#include "velox/functions/sparksql/Size.h"
#include "velox/functions/sparksql/String.h"
#include "velox/functions/sparksql/UnscaledValueFunction.h"

namespace facebook::velox::functions {
extern void registerElementAtFunction(
Expand Down Expand Up @@ -269,6 +270,12 @@ void registerFunctions(const std::string& prefix) {
{prefix + "might_contain"});

registerArrayMinMaxFunctions(prefix);

// Register decimal vector functions.
exec::registerVectorFunction(
prefix + "unscaled_value",
unscaledValueSignatures(),
makeUnscaledValue());
}

} // namespace sparksql
Expand Down
63 changes: 63 additions & 0 deletions velox/functions/sparksql/UnscaledValueFunction.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/sparksql/UnscaledValueFunction.h"

#include "velox/expression/DecodedArgs.h"

namespace facebook::velox::functions::sparksql {
namespace {

// Return the unscaled bigint value of a decimal, assuming it
// fits in a bigint. Only short decimal input is accepted.
class UnscaledValueFunction final : public exec::VectorFunction {
void apply(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const final {
VELOX_USER_CHECK(
args[0]->type()->isShortDecimal(),
"Expect short decimal type, but got: {}",
args[0]->type());
exec::DecodedArgs decodedArgs(rows, args, context);
auto decimalVector = decodedArgs.at(0);
context.ensureWritable(rows, BIGINT(), result);
result->clearNulls(rows);
auto flatResult =
result->asUnchecked<FlatVector<int64_t>>()->mutableRawValues();
rows.applyToSelected([&](auto row) {
flatResult[row] = decimalVector->valueAt<int64_t>(row);
});
}
};
} // namespace

std::vector<std::shared_ptr<exec::FunctionSignature>>
unscaledValueSignatures() {
return {exec::FunctionSignatureBuilder()
.integerVariable("precision")
.integerVariable("scale")
.returnType("bigint")
.argumentType("DECIMAL(precision, scale)")
.build()};
}

std::unique_ptr<exec::VectorFunction> makeUnscaledValue() {
return std::make_unique<UnscaledValueFunction>();
}

} // namespace facebook::velox::functions::sparksql
25 changes: 25 additions & 0 deletions velox/functions/sparksql/UnscaledValueFunction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "velox/expression/VectorFunction.h"

namespace facebook::velox::functions::sparksql {

std::vector<std::shared_ptr<exec::FunctionSignature>> unscaledValueSignatures();
std::unique_ptr<exec::VectorFunction> makeUnscaledValue();

} // namespace facebook::velox::functions::sparksql
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ add_executable(
SortArrayTest.cpp
SplitFunctionsTest.cpp
StringTest.cpp
UnscaledValueFunctionTest.cpp
XxHash64Test.cpp)

add_test(velox_functions_spark_test velox_functions_spark_test)
Expand Down
42 changes: 42 additions & 0 deletions velox/functions/sparksql/tests/UnscaledValueFunctionTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

using namespace facebook::velox::test;

namespace facebook::velox::functions::sparksql::test {
namespace {

class UnscaledValueFunctionTest : public SparkFunctionBaseTest {};

TEST_F(UnscaledValueFunctionTest, unscaledValue) {
auto testUnscaledValue = [&](const std::vector<int64_t>& unscaledValue,
const TypePtr& decimalType) {
auto input = makeFlatVector<int64_t>(unscaledValue, decimalType);
auto expected = makeFlatVector<int64_t>(unscaledValue);
auto result = evaluate("unscaled_value(c0)", makeRowVector({input}));
assertEqualVectors(expected, result);
};

testUnscaledValue({1000, 2000, -3000, -4000}, DECIMAL(18, 3));

VELOX_ASSERT_THROW(
testUnscaledValue({1000, 2000, -3000, -4000}, DECIMAL(20, 3)),
"Expect short decimal type, but got: DECIMAL(20, 3)");
}
} // namespace
} // namespace facebook::velox::functions::sparksql::test

0 comments on commit 703f2f2

Please sign in to comment.