Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[native] Add support for REST based remote function #23568

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions presto-native-execution/presto_cpp/main/PrestoServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,10 +1342,11 @@ void PrestoServer::registerRemoteFunctions() {
} else {
VELOX_FAIL(
"To register remote functions using a json file path you need to "
"specify the remote server location using '{}', '{}' or '{}'.",
"specify the remote server location using '{}', '{}' or '{}' or {}.",
SystemConfig::kRemoteFunctionServerThriftAddress,
SystemConfig::kRemoteFunctionServerThriftPort,
SystemConfig::kRemoteFunctionServerThriftUdsPath);
SystemConfig::kRemoteFunctionServerThriftUdsPath,
SystemConfig::kRemoteFunctionServerRestURL);
}
}
#endif
Expand Down
4 changes: 4 additions & 0 deletions presto-native-execution/presto_cpp/main/common/Configs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ std::string SystemConfig::remoteFunctionServerSerde() const {
return optionalProperty(kRemoteFunctionServerSerde).value();
}

/// Returns the value of the kRemoteFunctionServerRestURL property.
// NOTE(review): value() is called without first checking that the optional
// holds a value; presumably the property must be configured (or have a
// registered default) before this accessor is used -- confirm.
std::string SystemConfig::remoteFunctionServerRestURL() const {
  const auto restUrl = optionalProperty(kRemoteFunctionServerRestURL);
  return restUrl.value();
}

/// Returns the value of the kMaxDriversPerTask property read as an int32_t.
// NOTE(review): value() is called without checking that the optional holds a
// value; presumably a default is registered for this property -- confirm.
int32_t SystemConfig::maxDriversPerTask() const {
  const auto driversPerTask = optionalProperty<int32_t>(kMaxDriversPerTask);
  return driversPerTask.value();
}
Expand Down
6 changes: 6 additions & 0 deletions presto-native-execution/presto_cpp/main/common/Configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,10 @@ class SystemConfig : public ConfigBase {
static constexpr std::string_view kRemoteFunctionServerThriftUdsPath{
"remote-function-server.thrift.uds-path"};

/// Name of the config property holding the base HTTP URL of the REST remote
/// function server. The value is used as the prefix of the per-function
/// endpoint ("<url>/v1/functions/...") built when a REST function handle is
/// converted to a Velox expression.
static constexpr std::string_view kRemoteFunctionServerRestURL{
"remote-function-server.rest.url"};

/// Path where json files containing signatures for remote functions can be
/// found.
static constexpr std::string_view
Expand Down Expand Up @@ -714,6 +718,8 @@ class SystemConfig : public ConfigBase {

std::string remoteFunctionServerSerde() const;

/// Returns the value configured for kRemoteFunctionServerRestURL.
std::string remoteFunctionServerRestURL() const;

int32_t maxDriversPerTask() const;

folly::Optional<int32_t> taskWriterCount() const;
Expand Down
12 changes: 10 additions & 2 deletions presto-native-execution/presto_cpp/main/types/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,16 @@ add_library(
add_dependencies(presto_types presto_operators presto_type_converter velox_type
velox_type_fbhive)

target_link_libraries(presto_types presto_type_converter velox_type_fbhive
velox_hive_partition_function velox_tpch_gen velox_functions_json)
target_link_libraries(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: please revert this formatting change.

presto_types presto_type_converter velox_type_fbhive
velox_hive_partition_function velox_tpch_gen velox_functions_json)

# When remote function support is enabled, presto_types additionally depends
# on velox_expression and the remote function targets, and links against
# presto_server_remote_function and velox_functions_remote.
if(PRESTO_ENABLE_REMOTE_FUNCTIONS)
add_dependencies(presto_types velox_expression presto_server_remote_function
velox_functions_remote)
target_link_libraries(presto_types presto_server_remote_function
velox_functions_remote)
endif()

set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool)

Expand Down
127 changes: 127 additions & 0 deletions presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,16 @@
#include "velox/vector/ComplexVector.h"
#include "velox/vector/ConstantVector.h"
#include "velox/vector/FlatVector.h"
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
#include "presto_cpp/main/JsonSignatureParser.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/functions/remote/client/Remote.h"
#endif

using namespace facebook::velox::core;
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
using facebook::velox::functions::remote::PageFormat;
#endif
using facebook::velox::TypeKind;

namespace facebook::presto {
Expand Down Expand Up @@ -131,6 +139,114 @@ std::string getFunctionName(const protocol::SqlFunctionId& functionId) {
: functionId;
}

#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
/// Returns the text between the first and second '.' of the function name
/// part of 'functionId', or an empty string when the name has fewer than two
/// dots.
///
/// Example: "json.x4.eq;INTEGER;INTEGER" -> "x4".
std::string getSchemaName(const protocol::SqlFunctionId& functionId) {
  // Everything from the first ';' onwards is dropped. With no ';' present,
  // find() yields npos and substr() keeps the whole id.
  const std::string qualifiedName = functionId.substr(0, functionId.find(';'));

  const auto schemaBegin = qualifiedName.find('.');
  if (schemaBegin == std::string::npos) {
    return "";
  }

  const auto schemaEnd = qualifiedName.find('.', schemaBegin + 1);
  if (schemaEnd == std::string::npos) {
    return "";
  }

  return qualifiedName.substr(schemaBegin + 1, schemaEnd - schemaBegin - 1);
}

/// Returns the part of 'input' that follows its last '.', or 'input' itself
/// when it contains no '.'.
std::string extractFunctionName(const std::string& input) {
  const auto separator = input.rfind('.');
  return separator == std::string::npos ? input : input.substr(separator + 1);
}

/// Percent-encodes 'value'. Alphanumeric characters and '-', '_', '.', '~'
/// are copied through unchanged; every other byte is written as '%' followed
/// by its two-digit lowercase hexadecimal value.
std::string urlEncode(const std::string& value) {
  static constexpr char kHexDigits[] = "0123456789abcdef";

  std::string encoded;
  encoded.reserve(value.size());
  for (const char ch : value) {
    // Go through unsigned char so that bytes >= 0x80 are classified and
    // formatted by their byte value rather than as negative numbers.
    const auto byte = static_cast<unsigned char>(ch);
    const bool keepAsIs =
        isalnum(byte) || ch == '-' || ch == '_' || ch == '.' || ch == '~';
    if (keepAsIs) {
      encoded.push_back(ch);
      continue;
    }
    encoded.push_back('%');
    encoded.push_back(kHexDigits[byte >> 4]);
    encoded.push_back(kHexDigits[byte & 0x0F]);
  }
  return encoded;
}

/// Registers the remote function described by 'restFunctionHandle' with
/// Velox and returns a call expression that invokes it.
///
/// The function is registered under the name returned by getFunctionName()
/// for the handle's function id. Its endpoint is built as
///   <rest url>/v1/functions/<schema>/<function>/<encoded function id>/<version>
/// where <schema> falls back to "default" when the function id carries no
/// schema component.
///
/// @param restFunctionHandle Handle carrying the function id, version and
/// signature of the remote function.
/// @param args Already converted argument expressions of the call.
/// @param returnType Return type of the resulting call expression.
/// @return CallTypedExpr over 'args' naming the registered function.
/// Fails via VELOX_FAIL when the configured serde is not "presto_page".
TypedExprPtr registerRestRemoteFunction(
    const protocol::RestFunctionHandle& restFunctionHandle,
    const std::vector<TypedExprPtr>& args,
    const velox::TypePtr& returnType) {
  const auto* systemConfig = SystemConfig::instance();

  velox::functions::RemoteVectorFunctionMetadata metadata;
  const auto serdeName = systemConfig->remoteFunctionServerSerde();
  if (serdeName != "presto_page") {
    VELOX_FAIL(
        "presto_page serde is expected by remote function server but got : '{}'",
        serdeName);
  }
  metadata.serdeFormat = PageFormat::PRESTO_PAGE;
  metadata.functionId = restFunctionHandle.functionId;
  metadata.version = restFunctionHandle.version;

  // getSchemaName() reports a missing schema as an empty string. Only record
  // a schema when one is present, so that the "default" fallback below takes
  // effect instead of producing an empty path segment in the URL.
  // NOTE(review): this leaves metadata.schema unset (rather than "") when the
  // id has no schema; confirm no consumer relies on an empty-string schema.
  const std::string schemaName = getSchemaName(restFunctionHandle.functionId);
  if (!schemaName.empty()) {
    metadata.schema = schemaName;
  }

  // The same name is used for the URL, the registration and the returned
  // expression, so compute it once.
  const auto functionName = getFunctionName(restFunctionHandle.functionId);

  metadata.location = fmt::format(
      "{}/v1/functions/{}/{}/{}/{}",
      systemConfig->remoteFunctionServerRestURL(),
      metadata.schema.value_or("default"),
      extractFunctionName(functionName),
      urlEncode(restFunctionHandle.functionId),
      restFunctionHandle.version);

  // Translate the Presto signature into a Velox function signature.
  const auto& prestoSignature = restFunctionHandle.signature;
  velox::exec::FunctionSignatureBuilder signatureBuilder;

  // Type variable constraints.
  for (const auto& typeVar : prestoSignature.typeVariableConstraints) {
    signatureBuilder.typeVariable(typeVar.name);
  }

  // Long variable constraints (integer variables).
  for (const auto& longVar : prestoSignature.longVariableConstraints) {
    signatureBuilder.integerVariable(longVar.name);
  }

  signatureBuilder.returnType(prestoSignature.returnType);

  for (const auto& argType : prestoSignature.argumentTypes) {
    signatureBuilder.argumentType(argType);
  }

  if (prestoSignature.variableArity) {
    signatureBuilder.variableArity();
  }

  std::vector<velox::exec::FunctionSignaturePtr> veloxSignatures = {
      signatureBuilder.build()};

  // NOTE(review): the trailing 'false' is presumably the "overwrite" flag of
  // registerRemoteFunction -- confirm against the Velox declaration.
  velox::functions::registerRemoteFunction(
      functionName, veloxSignatures, metadata, false);

  return std::make_shared<CallTypedExpr>(returnType, args, functionName);
}
#endif

} // namespace

velox::variant VeloxExprConverter::getConstantValue(
Expand Down Expand Up @@ -513,6 +629,17 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr(
return std::make_shared<CallTypedExpr>(
returnType, args, getFunctionName(sqlFunctionHandle->functionId));
}
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this logic to a separate function, say registerVeloxRemoteFunction?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback, I have made the improvements

else if (
auto restFunctionHandle =
std::dynamic_pointer_cast<protocol::RestFunctionHandle>(
pexpr.functionHandle)) {
auto args = toVeloxExpr(pexpr.arguments);
auto returnType = typeParser_->parse(pexpr.returnType);

return registerRestRemoteFunction(*restFunctionHandle, args, returnType);
}
#endif

VELOX_FAIL("Unsupported function handle: {}", pexpr.functionHandle->_type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ class VeloxExprConverter {
std::optional<velox::core::TypedExprPtr> tryConvertDate(
const protocol::CallExpression& pexpr) const;

#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
/// Registers the remote function described by 'restFunctionHandle' and
/// returns a call expression over 'args' with the given 'returnType'.
// NOTE(review): the definition of registerRestRemoteFunction visible in this
// change is a non-member function in an anonymous namespace, not a member of
// VeloxExprConverter. Confirm this member declaration is defined somewhere
// (or remove it), since calls from member functions resolve to the member.
velox::core::TypedExprPtr registerRestRemoteFunction(
const protocol::RestFunctionHandle& restFunctionHandle,
const std::vector<velox::core::TypedExprPtr>& args,
const velox::TypePtr& returnType) const;
#endif

velox::memory::MemoryPool* const pool_;
TypeParser* const typeParser_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,29 @@ target_link_libraries(
velox_exec_test_lib
GTest::gtest
GTest::gtest_main)

# PrestoToVeloxExprTest.cpp only contains test cases related to
# RestFunctionHandle, so the test target is only defined when remote functions
# are enabled.
if(PRESTO_ENABLE_REMOTE_FUNCTIONS)
add_executable(presto_to_velox_expr_test PrestoToVeloxExprTest.cpp)

add_test(
NAME presto_to_velox_expr_test
COMMAND presto_to_velox_expr_test
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})

add_dependencies(presto_to_velox_expr_test presto_server_test)

target_link_libraries(
presto_to_velox_expr_test
presto_protocol
presto_operators
presto_type_converter
presto_types
velox_dwio_common
velox_hive_connector
velox_tpch_connector
GTest::gtest
GTest::gtest_main)
endif()
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "presto_cpp/main/types/PrestoToVeloxExpr.h"
#include <gtest/gtest.h>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/remote/client/Remote.h"
#include "velox/functions/remote/server/RemoteFunctionService.h"

using namespace facebook::presto;
using namespace facebook::velox;

class RemoteFunctionTest : public ::testing::Test {
protected:
static void SetUpTestCase() {
memory::MemoryManager::testingSetInstance({});
}

void SetUp() override {
memoryPool_ = memory::MemoryManager::getInstance()->addLeafPool();
converter_ =
std::make_unique<VeloxExprConverter>(memoryPool_.get(), &typeParser_);

functionHandle = std::make_shared<protocol::RestFunctionHandle>();
functionHandle->functionId = "remote.testSchema.testFunction;BIGINT;BIGINT";
functionHandle->version = "v1";
functionHandle->signature.name = "testFunction";
functionHandle->signature.returnType = "bigint";
functionHandle->signature.argumentTypes = {"bigint", "bigint"};
functionHandle->signature.typeVariableConstraints = {};
functionHandle->signature.longVariableConstraints = {};
functionHandle->signature.variableArity = false;

expectedMetadata.serdeFormat = functions::remote::PageFormat::PRESTO_PAGE;
expectedMetadata.functionId = functionHandle->functionId;
expectedMetadata.version = functionHandle->version;
expectedMetadata.schema = "testSchema";

testExpr.functionHandle = functionHandle;
testExpr.returnType = "bigint";
testExpr.displayName = "testFunction";
auto cexpr = std::make_shared<protocol::ConstantExpression>();
cexpr->type = "bigint";
cexpr->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA";
testExpr.arguments.push_back(cexpr);

auto cexpr2 = std::make_shared<protocol::ConstantExpression>();
cexpr2->type = "bigint";
cexpr2->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA";
testExpr.arguments.push_back(cexpr2);
}

std::unique_ptr<config::ConfigBase> restSystemConfig(
const std::unordered_map<std::string, std::string> configOverride = {})
const {
std::unordered_map<std::string, std::string> systemConfig{
{std::string(SystemConfig::kRemoteFunctionServerSerde),
std::string("presto_page")},
{std::string(SystemConfig::kRemoteFunctionServerRestURL),
std::string("http://localhost:8080")}};

for (const auto& [configName, configValue] : configOverride) {
systemConfig[configName] = configValue;
}
return std::make_unique<config::ConfigBase>(std::move(systemConfig), true);
}

std::shared_ptr<protocol::RestFunctionHandle> functionHandle;
protocol::CallExpression testExpr;
functions::RemoteVectorFunctionMetadata expectedMetadata;
std::shared_ptr<memory::MemoryPool> memoryPool_;
TypeParser typeParser_;
std::unique_ptr<VeloxExprConverter> converter_;
};

// Converts a call expression carrying a RestFunctionHandle and checks that the
// result is a CallTypedExpr named after the function with two constant bigint
// inputs.
TEST_F(RemoteFunctionTest, HandlesRestFunctionCorrectly) {
  try {
    auto restConfig = restSystemConfig();
    auto systemConfig = SystemConfig::instance();
    systemConfig->initialize(std::move(restConfig));
    auto expr = converter_->toVeloxExpr(testExpr);
    auto callExpr = std::dynamic_pointer_cast<const core::CallTypedExpr>(expr);
    ASSERT_NE(callExpr, nullptr);
    EXPECT_EQ(callExpr->name(), "remote.testSchema.testFunction");

    // Must be a fatal assertion: the inputs are indexed right below, so a
    // wrong size has to stop the test instead of reading out of bounds.
    ASSERT_EQ(callExpr->inputs().size(), 2);
    auto arg0 = std::dynamic_pointer_cast<const core::ConstantTypedExpr>(
        callExpr->inputs()[0]);
    auto arg1 = std::dynamic_pointer_cast<const core::ConstantTypedExpr>(
        callExpr->inputs()[1]);
    ASSERT_NE(arg0, nullptr);
    ASSERT_NE(arg1, nullptr);
    EXPECT_EQ(arg0->type()->kind(), TypeKind::BIGINT);
    EXPECT_EQ(arg1->type()->kind(), TypeKind::BIGINT);

  } catch (const std::exception& e) {
    FAIL() << "Exception: " << e.what();
  }
}

// Configures a serde other than "presto_page" and expects the conversion of a
// REST function call to fail with the serde error message.
TEST_F(RemoteFunctionTest, UnsupportedSerdeFormat) {
  const std::unordered_map<std::string, std::string> serdeOverride{
      {std::string(SystemConfig::kRemoteFunctionServerSerde),
       std::string("spark_unsafe_rows")}};
  SystemConfig::instance()->initialize(restSystemConfig(serdeOverride));

  VELOX_ASSERT_THROW(
      converter_->toVeloxExpr(testExpr),
      "presto_page serde is expected by remote function server but got : 'spark_unsafe_rows'");
}
Loading
Loading