From 992989212e840d673804a14dd41e7ff14eccd36b Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Mon, 7 Oct 2024 17:42:06 +0530 Subject: [PATCH] Add support for REST based remote functions. Co-authored-by: Wills Feng --- presto-native-execution/CMakeLists.txt | 2 +- .../presto_cpp/main/PrestoServer.cpp | 5 +- .../presto_cpp/main/common/Configs.cpp | 4 + .../presto_cpp/main/common/Configs.h | 6 ++ .../presto_cpp/main/types/CMakeLists.txt | 12 ++- .../main/types/PrestoToVeloxExpr.cpp | 75 +++++++++++++++++++ .../presto_cpp/main/types/PrestoToVeloxExpr.h | 7 ++ .../core/presto_protocol_core.cpp | 57 ++++++++++++++ .../core/presto_protocol_core.h | 11 +++ .../core/presto_protocol_core.yml | 2 + .../presto/spi/function/Signature.java | 7 ++ 11 files changed, 183 insertions(+), 5 deletions(-) diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt index d5001dde70a74..1556e97c0fc45 100644 --- a/presto-native-execution/CMakeLists.txt +++ b/presto-native-execution/CMakeLists.txt @@ -57,7 +57,7 @@ option(PRESTO_ENABLE_ABFS "Build ABFS support" OFF) option(PRESTO_ENABLE_PARQUET "Enable Parquet support" OFF) # Forwards user input to VELOX_ENABLE_REMOTE_FUNCTIONS. -option(PRESTO_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" OFF) +option(PRESTO_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" ON) option(PRESTO_ENABLE_TESTING "Enable tests" ON) diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 0a422c66e16bd..c694af58b3f10 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -1342,10 +1342,11 @@ void PrestoServer::registerRemoteFunctions() { } else { VELOX_FAIL( "To register remote functions using a json file path you need to " - "specify the remote server location using '{}', '{}' or '{}'.", + "specify the remote server location using '{}', '{}' or '{}' or {}.", SystemConfig::kRemoteFunctionServerThriftAddress, SystemConfig::kRemoteFunctionServerThriftPort, - SystemConfig::kRemoteFunctionServerThriftUdsPath); + SystemConfig::kRemoteFunctionServerThriftUdsPath, + SystemConfig::kRemoteFunctionServerRestURL); } } #endif diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp index 2ef3d8253ab3d..7f4cff8592cb5 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.cpp +++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp @@ -342,6 +342,10 @@ std::string SystemConfig::remoteFunctionServerSerde() const { return optionalProperty(kRemoteFunctionServerSerde).value(); } +std::string SystemConfig::remoteFunctionRestUrl() const { + return optionalProperty(kRemoteFunctionServerRestURL).value(); +} + int32_t SystemConfig::maxDriversPerTask() const { return optionalProperty(kMaxDriversPerTask).value(); } diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h index 5969006e62575..77c965c5a4750 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.h +++ b/presto-native-execution/presto_cpp/main/common/Configs.h @@ -617,6 +617,10 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kRemoteFunctionServerThriftUdsPath{ "remote-function-server.thrift.uds-path"}; + /// HTTP URL used by the remote function rest server. + static constexpr std::string_view kRemoteFunctionServerRestURL{ + "remote-function-server.rest.url"}; + /// Path where json files containing signatures for remote functions can be /// found. static constexpr std::string_view @@ -714,6 +718,8 @@ class SystemConfig : public ConfigBase { std::string remoteFunctionServerSerde() const; + std::string remoteFunctionRestUrl() const; + int32_t maxDriversPerTask() const; folly::Optional taskWriterCount() const; diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index 5841728512238..bc6438563028f 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -21,8 +21,16 @@ add_library( add_dependencies(presto_types presto_operators presto_type_converter velox_type velox_type_fbhive) -target_link_libraries(presto_types presto_type_converter velox_type_fbhive - velox_hive_partition_function velox_tpch_gen velox_functions_json) +target_link_libraries( + presto_types presto_type_converter velox_type_fbhive + velox_hive_partition_function velox_tpch_gen velox_functions_json) + +if(PRESTO_ENABLE_REMOTE_FUNCTIONS) + add_dependencies(presto_types velox_expression presto_server_remote_function + velox_functions_remote) + target_link_libraries(presto_types presto_server_remote_function + velox_functions_remote) +endif() set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool) diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index fe9731c116f4f..b5b42001c7525 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -21,8 +21,16 @@ #include "velox/vector/ComplexVector.h" #include "velox/vector/ConstantVector.h" #include "velox/vector/FlatVector.h" +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +#include "presto_cpp/main/JsonSignatureParser.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/remote/client/Remote.h" +#endif using namespace facebook::velox::core; +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +using facebook::velox::functions::remote::PageFormat; +#endif using facebook::velox::TypeKind; namespace facebook::presto { @@ -127,6 +135,61 @@ std::string getFunctionName(const protocol::SqlFunctionId& functionId) { } // namespace +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +TypedExprPtr VeloxExprConverter::registerRestRemoteFunction( + const protocol::RestFunctionHandle& restFunctionHandle, + const std::vector& args, + const velox::TypePtr& returnType) const { + const auto* systemConfig = SystemConfig::instance(); + + velox::functions::RemoteVectorFunctionMetadata metadata; + const auto& serdeName = systemConfig->remoteFunctionServerSerde(); + if (serdeName == "presto_page") { + metadata.serdeFormat = PageFormat::PRESTO_PAGE; + } else { + VELOX_FAIL( + "presto_page serde is expected by remote function server but got : '{}'", + serdeName); + } + metadata.location = systemConfig->remoteFunctionRestUrl(); + metadata.functionId = restFunctionHandle.functionId; + metadata.version = restFunctionHandle.version; + + const auto& prestoSignature = restFunctionHandle.signature; + velox::exec::FunctionSignatureBuilder signatureBuilder; + + for (const auto& typeVar : prestoSignature.typeVariableConstraints) { + signatureBuilder.typeVariable(typeVar.name); + } + + for (const auto& longVar : prestoSignature.longVariableConstraints) { + signatureBuilder.integerVariable(longVar.name); + } + + signatureBuilder.returnType(prestoSignature.returnType); + + for (const auto& argType : prestoSignature.argumentTypes) { + signatureBuilder.argumentType(argType); + } + + if (prestoSignature.variableArity) { + signatureBuilder.variableArity(); + } + + auto signature = signatureBuilder.build(); + std::vector veloxSignatures = {signature}; + + velox::functions::registerRemoteFunction( + getFunctionName(restFunctionHandle.functionId), + veloxSignatures, + metadata, + false); + + return std::make_shared( + returnType, args, getFunctionName(restFunctionHandle.functionId)); +} +#endif + velox::variant VeloxExprConverter::getConstantValue( const velox::TypePtr& type, const protocol::Block& block) const { @@ -504,10 +567,22 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( pexpr.functionHandle)) { auto args = toVeloxExpr(pexpr.arguments); auto returnType = typeParser_->parse(pexpr.returnType); + return std::make_shared( returnType, args, getFunctionName(sqlFunctionHandle->functionId)); } +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS + else if ( + auto restFunctionHandle = + std::dynamic_pointer_cast( + pexpr.functionHandle)) { + // Defer to our new helper function for restFunctionHandle. + auto args = toVeloxExpr(pexpr.arguments); + auto returnType = typeParser_->parse(pexpr.returnType); + return registerRestRemoteFunction(*restFunctionHandle, args, returnType); + } +#endif VELOX_FAIL("Unsupported function handle: {}", pexpr.functionHandle->_type); } diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h index 8526f6a8c9638..8663d4a527166 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.h @@ -64,6 +64,13 @@ class VeloxExprConverter { std::optional tryConvertDate( const protocol::CallExpression& pexpr) const; +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS + velox::core::TypedExprPtr registerRestRemoteFunction( + const protocol::RestFunctionHandle& restFunctionHandle, + const std::vector& args, + const velox::TypePtr& returnType) const; +#endif + velox::memory::MemoryPool* const pool_; TypeParser* const typeParser_; }; diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index b0ecb4a22354a..c8e63f50f7b6e 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -116,6 +116,10 @@ void to_json(json& j, const std::shared_ptr& p) { j = *std::static_pointer_cast(p); return; } + if (type == "rest") { + j = *std::static_pointer_cast(p); + return; + } throw TypeError(type + " no abstract type FunctionHandle "); } @@ -149,6 +153,13 @@ void from_json(const json& j, std::shared_ptr& p) { p = std::static_pointer_cast(k); return; } + if (type == "rest") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } throw TypeError(type + " no abstract type FunctionHandle "); } @@ -8356,6 +8367,52 @@ void from_json(const json& j, RemoteTransactionHandle& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +RestFunctionHandle::RestFunctionHandle() noexcept { + _type = "rest"; +} + +void to_json(json& j, const RestFunctionHandle& p) { + j = json::object(); + j["@type"] = "rest"; + to_json_key( + j, + "functionId", + p.functionId, + "RestFunctionHandle", + "SqlFunctionId", + "functionId"); + to_json_key( + j, "version", p.version, "RestFunctionHandle", "String", "version"); + to_json_key( + j, + "signature", + p.signature, + "RestFunctionHandle", + "Signature", + "signature"); +} + +void from_json(const json& j, RestFunctionHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "functionId", + p.functionId, + "RestFunctionHandle", + "SqlFunctionId", + "functionId"); + from_json_key( + j, "version", p.version, "RestFunctionHandle", "String", "version"); + from_json_key( + j, + "signature", + p.signature, + "RestFunctionHandle", + "Signature", + "signature"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { RowNumberNode::RowNumberNode() noexcept { _type = "com.facebook.presto.sql.planner.plan.RowNumberNode"; } diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index 907f160650cdb..f52d1be9ca0ee 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -1947,6 +1947,17 @@ void to_json(json& j, const RemoteTransactionHandle& p); void from_json(const json& j, RemoteTransactionHandle& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct RestFunctionHandle : public FunctionHandle { + SqlFunctionId functionId = {}; + String version = {}; + Signature signature = {}; + + RestFunctionHandle() noexcept; +}; +void to_json(json& j, const RestFunctionHandle& p); +void from_json(const json& j, RestFunctionHandle& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct RowNumberNode : public PlanNode { std::shared_ptr source = {}; List partitionBy = {}; diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml index 32feb37fd501e..e160df972c0b8 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml @@ -176,6 +176,7 @@ AbstractClasses: - { name: BuiltInFunctionHandle, key: $static } - { name: SqlFunctionHandle, key: native } - { name: SqlFunctionHandle, key: json_file } + - { name: RestFunctionHandle, key: rest } JavaClasses: @@ -192,6 +193,7 @@ JavaClasses: - presto-main/src/main/java/com/facebook/presto/execution/buffer/BufferState.java - presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java - presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionHandle.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/RestFunctionHandle.java - presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaRequirement.java - presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaScope.java - presto-spi/src/main/java/com/facebook/presto/spi/relation/CallExpression.java diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java index 024ba43035bd0..bf949e40da0d7 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java @@ -35,12 +35,19 @@ @ThriftStruct public final class Signature { + @JsonProperty("name") private final QualifiedObjectName name; + @JsonProperty("kind") private final FunctionKind kind; + @JsonProperty("typeVariableConstraints") private final List typeVariableConstraints; + @JsonProperty("longVariableConstraints") private final List longVariableConstraints; + @JsonProperty("returnType") private final TypeSignature returnType; + @JsonProperty("argumentTypes") private final List argumentTypes; + @JsonProperty("variableArity") private final boolean variableArity; @ThriftConstructor