diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 0a422c66e16bd..c694af58b3f10 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -1342,10 +1342,11 @@ void PrestoServer::registerRemoteFunctions() { } else { VELOX_FAIL( "To register remote functions using a json file path you need to " - "specify the remote server location using '{}', '{}' or '{}'.", + "specify the remote server location using '{}', '{}' or '{}' or {}.", SystemConfig::kRemoteFunctionServerThriftAddress, SystemConfig::kRemoteFunctionServerThriftPort, - SystemConfig::kRemoteFunctionServerThriftUdsPath); + SystemConfig::kRemoteFunctionServerThriftUdsPath, + SystemConfig::kRemoteFunctionServerRestURL); } } #endif diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp index 2ef3d8253ab3d..60e3bafef4ec0 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.cpp +++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp @@ -342,6 +342,10 @@ std::string SystemConfig::remoteFunctionServerSerde() const { return optionalProperty(kRemoteFunctionServerSerde).value(); } +std::string SystemConfig::remoteFunctionServerRestURL() const { + return optionalProperty(kRemoteFunctionServerRestURL).value(); +} + int32_t SystemConfig::maxDriversPerTask() const { return optionalProperty(kMaxDriversPerTask).value(); } diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h index 5969006e62575..82e654cdea792 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.h +++ b/presto-native-execution/presto_cpp/main/common/Configs.h @@ -617,6 +617,10 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kRemoteFunctionServerThriftUdsPath{ "remote-function-server.thrift.uds-path"}; + /// HTTP URL used by the remote function rest server. + static constexpr std::string_view kRemoteFunctionServerRestURL{ + "remote-function-server.rest.url"}; + /// Path where json files containing signatures for remote functions can be /// found. static constexpr std::string_view @@ -714,6 +718,8 @@ class SystemConfig : public ConfigBase { std::string remoteFunctionServerSerde() const; + std::string remoteFunctionServerRestURL() const; + int32_t maxDriversPerTask() const; folly::Optional taskWriterCount() const; diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index 5841728512238..bc6438563028f 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -21,8 +21,16 @@ add_library( add_dependencies(presto_types presto_operators presto_type_converter velox_type velox_type_fbhive) -target_link_libraries(presto_types presto_type_converter velox_type_fbhive - velox_hive_partition_function velox_tpch_gen velox_functions_json) +target_link_libraries( + presto_types presto_type_converter velox_type_fbhive + velox_hive_partition_function velox_tpch_gen velox_functions_json) + +if(PRESTO_ENABLE_REMOTE_FUNCTIONS) + add_dependencies(presto_types velox_expression presto_server_remote_function + velox_functions_remote) + target_link_libraries(presto_types presto_server_remote_function + velox_functions_remote) +endif() set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool) diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index 83d359f6d6aad..6d61378fa7d42 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -21,8 +21,16 @@ #include "velox/vector/ComplexVector.h" #include "velox/vector/ConstantVector.h" #include "velox/vector/FlatVector.h" +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +#include "presto_cpp/main/JsonSignatureParser.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/remote/client/Remote.h" +#endif using namespace facebook::velox::core; +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +using facebook::velox::functions::remote::PageFormat; +#endif using facebook::velox::TypeKind; namespace facebook::presto { @@ -131,6 +139,114 @@ std::string getFunctionName(const protocol::SqlFunctionId& functionId) { : functionId; } +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +std::string getSchemaName(const protocol::SqlFunctionId& functionId) { + // Example: "json.x4.eq;INTEGER;INTEGER". + const auto nameEnd = functionId.find(';'); + std::string functionName = (nameEnd != std::string::npos) + ? functionId.substr(0, nameEnd) + : functionId; + + const auto firstDot = functionName.find('.'); + const auto secondDot = functionName.find('.', firstDot + 1); + if (firstDot != std::string::npos && secondDot != std::string::npos) { + return functionName.substr(firstDot + 1, secondDot - firstDot - 1); + } + + return ""; +} + +std::string extractFunctionName(const std::string& input) { + size_t lastDot = input.find_last_of('.'); + if (lastDot != std::string::npos) { + return input.substr(lastDot + 1); + } + return input; +} + +std::string urlEncode(const std::string& value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + for (char c : value) { + if (isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '~') { + escaped << c; + } else { + escaped << '%' << std::setw(2) << int(static_cast(c)); + } + } + return escaped.str(); +} + +TypedExprPtr registerRestRemoteFunction( + const protocol::RestFunctionHandle& restFunctionHandle, + const std::vector& args, + const velox::TypePtr& returnType) { + const auto* systemConfig = SystemConfig::instance(); + + velox::functions::RemoteVectorFunctionMetadata metadata; + const auto& serdeName = systemConfig->remoteFunctionServerSerde(); + if (serdeName == "presto_page") { + metadata.serdeFormat = PageFormat::PRESTO_PAGE; + } else { + VELOX_FAIL( + "presto_page serde is expected by remote function server but got : '{}'", + serdeName); + } + metadata.functionId = restFunctionHandle.functionId; + metadata.version = restFunctionHandle.version; + metadata.schema = getSchemaName(restFunctionHandle.functionId); + + const std::string location = fmt::format( + "{}/v1/functions/{}/{}/{}/{}", + systemConfig->remoteFunctionServerRestURL(), + metadata.schema.value_or("default"), + extractFunctionName(getFunctionName(restFunctionHandle.functionId)), + urlEncode(restFunctionHandle.functionId), + restFunctionHandle.version); + metadata.location = location; + + const auto& prestoSignature = restFunctionHandle.signature; + // parseTypeSignature + velox::exec::FunctionSignatureBuilder signatureBuilder; + // Handle type variable constraints + for (const auto& typeVar : prestoSignature.typeVariableConstraints) { + signatureBuilder.typeVariable(typeVar.name); + } + + // Handle long variable constraints (for integer variables) + for (const auto& longVar : prestoSignature.longVariableConstraints) { + signatureBuilder.integerVariable(longVar.name); + } + + // Handle return type + signatureBuilder.returnType(prestoSignature.returnType); + + // Handle argument types + for (const auto& argType : prestoSignature.argumentTypes) { + signatureBuilder.argumentType(argType); + } + + // Handle variable arity + if (prestoSignature.variableArity) { + signatureBuilder.variableArity(); + } + + auto signature = signatureBuilder.build(); + std::vector veloxSignatures = {signature}; + + velox::functions::registerRemoteFunction( + getFunctionName(restFunctionHandle.functionId), + veloxSignatures, + metadata, + false); + + return std::make_shared( + returnType, args, getFunctionName(restFunctionHandle.functionId)); +} +#endif + } // namespace velox::variant VeloxExprConverter::getConstantValue( @@ -513,6 +629,17 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( return std::make_shared( returnType, args, getFunctionName(sqlFunctionHandle->functionId)); } +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS + else if ( + auto restFunctionHandle = + std::dynamic_pointer_cast( + pexpr.functionHandle)) { + auto args = toVeloxExpr(pexpr.arguments); + auto returnType = typeParser_->parse(pexpr.returnType); + + return registerRestRemoteFunction(*restFunctionHandle, args, returnType); + } +#endif VELOX_FAIL("Unsupported function handle: {}", pexpr.functionHandle->_type); } diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h index 8526f6a8c9638..8663d4a527166 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h @@ -64,6 +64,13 @@ class VeloxExprConverter { std::optional tryConvertDate( const protocol::CallExpression& pexpr) const; +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS + velox::core::TypedExprPtr registerRestRemoteFunction( + const protocol::RestFunctionHandle& restFunctionHandle, + const std::vector& args, + const velox::TypePtr& returnType) const; +#endif + velox::memory::MemoryPool* const pool_; TypeParser* const typeParser_; }; diff --git a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt index 28f73aff40b80..ec7b7d6441a64 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt @@ -131,3 +131,29 @@ target_link_libraries( velox_exec_test_lib GTest::gtest GTest::gtest_main) + +# PrestoToVeloxExprTest.cpp only contains tests cases related to +# RestFunctionHandle, therefore it is only enabled when remote functions are +# enabled. +if(PRESTO_ENABLE_REMOTE_FUNCTIONS) + add_executable(presto_to_velox_expr_test PrestoToVeloxExprTest.cpp) + + add_test( + NAME presto_to_velox_expr_test + COMMAND presto_to_velox_expr_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + + add_dependencies(presto_to_velox_expr_test presto_server_test) + + target_link_libraries( + presto_to_velox_expr_test + presto_protocol + presto_operators + presto_type_converter + presto_types + velox_dwio_common + velox_hive_connector + velox_tpch_connector + GTest::gtest + GTest::gtest_main) +endif() diff --git a/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxExprTest.cpp b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxExprTest.cpp new file mode 100644 index 0000000000000..70f289865cd15 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/types/tests/PrestoToVeloxExprTest.cpp @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/types/PrestoToVeloxExpr.h" +#include +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/remote/client/Remote.h" +#include "velox/functions/remote/server/RemoteFunctionService.h" + +using namespace facebook::presto; +using namespace facebook::velox; + +class RemoteFunctionTest : public ::testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + + void SetUp() override { + memoryPool_ = memory::MemoryManager::getInstance()->addLeafPool(); + converter_ = + std::make_unique(memoryPool_.get(), &typeParser_); + + functionHandle = std::make_shared(); + functionHandle->functionId = "remote.testSchema.testFunction;BIGINT;BIGINT"; + functionHandle->version = "v1"; + functionHandle->signature.name = "testFunction"; + functionHandle->signature.returnType = "bigint"; + functionHandle->signature.argumentTypes = {"bigint", "bigint"}; + functionHandle->signature.typeVariableConstraints = {}; + functionHandle->signature.longVariableConstraints = {}; + functionHandle->signature.variableArity = false; + + expectedMetadata.serdeFormat = functions::remote::PageFormat::PRESTO_PAGE; + expectedMetadata.functionId = functionHandle->functionId; + expectedMetadata.version = functionHandle->version; + expectedMetadata.schema = "testSchema"; + + testExpr.functionHandle = functionHandle; + testExpr.returnType = "bigint"; + testExpr.displayName = "testFunction"; + auto cexpr = std::make_shared(); + cexpr->type = "bigint"; + cexpr->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA"; + testExpr.arguments.push_back(cexpr); + + auto cexpr2 = std::make_shared(); + cexpr2->type = "bigint"; + cexpr2->valueBlock.data = "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA"; + testExpr.arguments.push_back(cexpr2); + } + + std::unique_ptr restSystemConfig( + const std::unordered_map configOverride = {}) + const { + std::unordered_map systemConfig{ + {std::string(SystemConfig::kRemoteFunctionServerSerde), + std::string("presto_page")}, + {std::string(SystemConfig::kRemoteFunctionServerRestURL), + std::string("http://localhost:8080")}}; + + for (const auto& [configName, configValue] : configOverride) { + systemConfig[configName] = configValue; + } + return std::make_unique(std::move(systemConfig), true); + } + + std::shared_ptr functionHandle; + protocol::CallExpression testExpr; + functions::RemoteVectorFunctionMetadata expectedMetadata; + std::shared_ptr memoryPool_; + TypeParser typeParser_; + std::unique_ptr converter_; +}; + +TEST_F(RemoteFunctionTest, HandlesRestFunctionCorrectly) { + try { + auto restConfig = restSystemConfig(); + auto systemConfig = SystemConfig::instance(); + systemConfig->initialize(std::move(restConfig)); + auto expr = converter_->toVeloxExpr(testExpr); + auto callExpr = std::dynamic_pointer_cast(expr); + ASSERT_NE(callExpr, nullptr); + EXPECT_EQ(callExpr->name(), "remote.testSchema.testFunction"); + + EXPECT_EQ(callExpr->inputs().size(), 2); + auto arg0 = std::dynamic_pointer_cast( + callExpr->inputs()[0]); + auto arg1 = std::dynamic_pointer_cast( + callExpr->inputs()[1]); + ASSERT_NE(arg0, nullptr); + ASSERT_NE(arg1, nullptr); + EXPECT_EQ(arg0->type()->kind(), TypeKind::BIGINT); + EXPECT_EQ(arg1->type()->kind(), TypeKind::BIGINT); + + } catch (const std::exception& e) { + FAIL() << "Exception: " << e.what(); + } +} + +TEST_F(RemoteFunctionTest, UnsupportedSerdeFormat) { + std::unordered_map restConfigOverride{ + {std::string(SystemConfig::kRemoteFunctionServerSerde), + std::string("spark_unsafe_rows")}}; + auto restConfig = restSystemConfig(restConfigOverride); + auto systemConfig = SystemConfig::instance(); + systemConfig->initialize(std::move(restConfig)); + + VELOX_ASSERT_THROW( + converter_->toVeloxExpr(testExpr), + "presto_page serde is expected by remote function server but got : 'spark_unsafe_rows'"); +} diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index b0ecb4a22354a..c8e63f50f7b6e 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -116,6 +116,10 @@ void to_json(json& j, const std::shared_ptr& p) { j = *std::static_pointer_cast(p); return; } + if (type == "rest") { + j = *std::static_pointer_cast(p); + return; + } throw TypeError(type + " no abstract type FunctionHandle "); } @@ -149,6 +153,13 @@ void from_json(const json& j, std::shared_ptr& p) { p = std::static_pointer_cast(k); return; } + if (type == "rest") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } throw TypeError(type + " no abstract type FunctionHandle "); } @@ -8356,6 +8367,52 @@ void from_json(const json& j, RemoteTransactionHandle& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +RestFunctionHandle::RestFunctionHandle() noexcept { + _type = "rest"; +} + +void to_json(json& j, const RestFunctionHandle& p) { + j = json::object(); + j["@type"] = "rest"; + to_json_key( + j, + "functionId", + p.functionId, + "RestFunctionHandle", + "SqlFunctionId", + "functionId"); + to_json_key( + j, "version", p.version, "RestFunctionHandle", "String", "version"); + to_json_key( + j, + "signature", + p.signature, + "RestFunctionHandle", + "Signature", + "signature"); +} + +void from_json(const json& j, RestFunctionHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "functionId", + p.functionId, + "RestFunctionHandle", + "SqlFunctionId", + "functionId"); + from_json_key( + j, "version", p.version, "RestFunctionHandle", "String", "version"); + from_json_key( + j, + "signature", + p.signature, + "RestFunctionHandle", + "Signature", + "signature"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { RowNumberNode::RowNumberNode() noexcept { _type = "com.facebook.presto.sql.planner.plan.RowNumberNode"; } diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index 907f160650cdb..f52d1be9ca0ee 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -1947,6 +1947,17 @@ void to_json(json& j, const RemoteTransactionHandle& p); void from_json(const json& j, RemoteTransactionHandle& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct RestFunctionHandle : public FunctionHandle { + SqlFunctionId functionId = {}; + String version = {}; + Signature signature = {}; + + RestFunctionHandle() noexcept; +}; +void to_json(json& j, const RestFunctionHandle& p); +void from_json(const json& j, RestFunctionHandle& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct RowNumberNode : public PlanNode { std::shared_ptr source = {}; List partitionBy = {}; diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml index 32feb37fd501e..e160df972c0b8 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml @@ -176,6 +176,7 @@ AbstractClasses: - { name: BuiltInFunctionHandle, key: $static } - { name: SqlFunctionHandle, key: native } - { name: SqlFunctionHandle, key: json_file } + - { name: RestFunctionHandle, key: rest } JavaClasses: @@ -192,6 +193,7 @@ JavaClasses: - presto-main/src/main/java/com/facebook/presto/execution/buffer/BufferState.java - presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java - presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionHandle.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/RestFunctionHandle.java - presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaRequirement.java - presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaScope.java - presto-spi/src/main/java/com/facebook/presto/spi/relation/CallExpression.java diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java index 024ba43035bd0..bf949e40da0d7 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java @@ -35,12 +35,19 @@ @ThriftStruct public final class Signature { + @JsonProperty("name") private final QualifiedObjectName name; + @JsonProperty("kind") private final FunctionKind kind; + @JsonProperty("typeVariableConstraints") private final List typeVariableConstraints; + @JsonProperty("longVariableConstraints") private final List longVariableConstraints; + @JsonProperty("returnType") private final TypeSignature returnType; + @JsonProperty("argumentTypes") private final List argumentTypes; + @JsonProperty("variableArity") private final boolean variableArity; @ThriftConstructor