From 4356902da8ec2a562d9028d7ccda33067730c0ed Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Mon, 7 Oct 2024 17:42:06 +0530 Subject: [PATCH] Add support for REST based remote functions. Co-authored-by: Wills Feng --- .../presto_cpp/main/PrestoServer.cpp | 5 +- .../presto_cpp/main/common/Configs.cpp | 4 + .../presto_cpp/main/common/Configs.h | 6 ++ .../presto_cpp/main/types/CMakeLists.txt | 8 ++ .../main/types/PrestoToVeloxExpr.cpp | 79 +++++++++++++++++ .../main/types/tests/CMakeLists.txt | 10 +++ .../core/presto_protocol_core.cpp | 85 +++++++++++++++++++ .../core/presto_protocol_core.h | 21 ++++- .../core/presto_protocol_core.yml | 3 +- .../presto/spi/function/Signature.java | 7 ++ 10 files changed, 221 insertions(+), 7 deletions(-) diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index a01eb58bdd178..f76493ba49ad6 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -1312,10 +1312,11 @@ void PrestoServer::registerRemoteFunctions() { } else { VELOX_FAIL( "To register remote functions using a json file path you need to " - "specify the remote server location using '{}', '{}' or '{}'.", + "specify the remote server location using '{}', '{}' or '{}' or {}.", SystemConfig::kRemoteFunctionServerThriftAddress, SystemConfig::kRemoteFunctionServerThriftPort, - SystemConfig::kRemoteFunctionServerThriftUdsPath); + SystemConfig::kRemoteFunctionServerThriftUdsPath, + SystemConfig::kRemoteFunctionServerRestURL); } } #endif diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp index 0929f43436a88..fa19332777504 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.cpp +++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp @@ -341,6 +341,10 @@ std::string SystemConfig::remoteFunctionServerSerde() const { return optionalProperty(kRemoteFunctionServerSerde).value(); } +std::string SystemConfig::remoteFunctionRestUrl() const { + return optionalProperty(kRemoteFunctionServerRestURL).value(); +} + int32_t SystemConfig::maxDriversPerTask() const { return optionalProperty(kMaxDriversPerTask).value(); } diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h index 56f54e30f3d4e..c536504376052 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.h +++ b/presto-native-execution/presto_cpp/main/common/Configs.h @@ -615,6 +615,10 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kRemoteFunctionServerThriftUdsPath{ "remote-function-server.thrift.uds-path"}; + /// HTTP URL used by the remote function rest server. + static constexpr std::string_view kRemoteFunctionServerRestURL{ + "remote-function-server.rest.url"}; + /// Path where json files containing signatures for remote functions can be /// found. static constexpr std::string_view @@ -708,6 +712,8 @@ class SystemConfig : public ConfigBase { std::string remoteFunctionServerSerde() const; + std::string remoteFunctionRestUrl() const; + int32_t maxDriversPerTask() const; folly::Optional taskWriterCount() const; diff --git a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt index c17e136e9984d..45e9a403e21bc 100644 --- a/presto-native-execution/presto_cpp/main/types/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/CMakeLists.txt @@ -18,12 +18,20 @@ add_library( presto_types OBJECT PrestoToVeloxQueryPlan.cpp PrestoToVeloxExpr.cpp VeloxPlanValidator.cpp PrestoToVeloxSplit.cpp PrestoToVeloxConnector.cpp) + add_dependencies(presto_types presto_operators presto_type_converter velox_type velox_type_fbhive) target_link_libraries(presto_types presto_type_converter velox_type_fbhive velox_hive_partition_function velox_tpch_gen) +if(PRESTO_ENABLE_REMOTE_FUNCTIONS) + add_dependencies(presto_types velox_expression presto_server_remote_function + velox_functions_remote) + target_link_libraries(presto_types presto_server_remote_function + velox_functions_remote) +endif() + set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool) add_library(presto_function_metadata OBJECT FunctionMetadata.cpp) diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp index 2c2b2a3c5ea00..b9b68c76faffc 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxExpr.cpp @@ -14,14 +14,23 @@ #include "presto_cpp/main/types/PrestoToVeloxExpr.h" #include +#include "presto_cpp/main/common/Configs.h" #include "presto_cpp/presto_protocol/Base64Util.h" #include "velox/common/base/Exceptions.h" #include "velox/functions/prestosql/types/JsonType.h" #include "velox/vector/ComplexVector.h" #include "velox/vector/ConstantVector.h" #include "velox/vector/FlatVector.h" +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +#include "presto_cpp/main/JsonSignatureParser.h" +#include "velox/expression/FunctionSignature.h" +#include "velox/functions/remote/client/Remote.h" +#endif using namespace facebook::velox::core; +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +using facebook::velox::functions::remote::PageFormat; +#endif using facebook::velox::TypeKind; namespace facebook::presto { @@ -412,6 +421,18 @@ std::optional VeloxExprConverter::tryConvertLike( returnType, args, getFunctionName(signature)); } +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS +PageFormat fromSerdeString(const std::string_view& serdeName) { + if (serdeName == "presto_page") { + return PageFormat::PRESTO_PAGE; + } else { + VELOX_FAIL( + "presto_page serde is expected by remote function server but got : '{}'", + serdeName); + } +} +#endif + TypedExprPtr VeloxExprConverter::toVeloxExpr( const protocol::CallExpression& pexpr) const { if (auto builtin = std::dynamic_pointer_cast( @@ -458,10 +479,68 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr( pexpr.functionHandle)) { auto args = toVeloxExpr(pexpr.arguments); auto returnType = typeParser_->parse(pexpr.returnType); + return std::make_shared( returnType, args, getFunctionName(sqlFunctionHandle->functionId)); } +#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS + else if ( + auto restFunctionHandle = + std::dynamic_pointer_cast( + pexpr.functionHandle)) { + + auto args = toVeloxExpr(pexpr.arguments); + auto returnType = typeParser_->parse(pexpr.returnType); + + const auto* systemConfig = SystemConfig::instance(); + + velox::functions::RemoteVectorFunctionMetadata metadata; + metadata.serdeFormat = + fromSerdeString(systemConfig->remoteFunctionServerSerde()); + metadata.location = systemConfig->remoteFunctionRestUrl(); + metadata.functionId = restFunctionHandle->functionId; + metadata.version = restFunctionHandle->version; + + const auto& prestoSignature = restFunctionHandle->signature; + // parseTypeSignature + velox::exec::FunctionSignatureBuilder signatureBuilder; + // Handle type variable constraints + for (const auto& typeVar : prestoSignature.typeVariableConstraints) { + signatureBuilder.typeVariable(typeVar.name); + } + // Handle long variable constraints (for integer variables) + for (const auto& longVar : prestoSignature.longVariableConstraints) { + signatureBuilder.integerVariable(longVar.name); + } + + // Handle return type + signatureBuilder.returnType(prestoSignature.returnType); + + // Handle argument types + for (const auto& argType : prestoSignature.argumentTypes) { + signatureBuilder.argumentType(argType); + } + + // Handle variable arity + if (prestoSignature.variableArity) { + signatureBuilder.variableArity(); + } + + auto signature = signatureBuilder.build(); + std::vector veloxSignatures = { + signature}; + + velox::functions::registerRemoteFunction( + getFunctionName(restFunctionHandle->functionId), + veloxSignatures, + metadata, + false); + + return std::make_shared( + returnType, args, getFunctionName(restFunctionHandle->functionId)); + } +#endif VELOX_FAIL("Unsupported function handle: {}", pexpr.functionHandle->_type); } diff --git a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt index caa5943ec3d79..cef583d0c93bc 100644 --- a/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/types/tests/CMakeLists.txt @@ -72,6 +72,16 @@ target_link_libraries( ${GFLAGS_LIBRARIES} pthread) +if(PRESTO_ENABLE_REMOTE_FUNCTIONS) + add_dependencies(presto_expressions_test presto_server_remote_function + velox_expression velox_functions_remote) + + target_link_libraries( + presto_expressions_test GTest::gmock GTest::gmock_main + presto_server_remote_function velox_expression velox_functions_remote) + +endif() + set_property(TARGET presto_expressions_test PROPERTY JOB_POOL_LINK presto_link_job_pool) diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 60b9a0a601e60..6dda5d207c0cb 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -112,6 +112,10 @@ void to_json(json& j, const std::shared_ptr& p) { j = *std::static_pointer_cast(p); return; } + if (type == "rest") { + j = *std::static_pointer_cast(p); + return; + } throw TypeError(type + " no abstract type FunctionHandle "); } @@ -138,6 +142,13 @@ void from_json(const json& j, std::shared_ptr& p) { p = std::static_pointer_cast(k); return; } + if (type == "rest") { + std::shared_ptr k = + std::make_shared(); + j.get_to(*k); + p = std::static_pointer_cast(k); + return; + } throw TypeError(type + " no abstract type FunctionHandle "); } @@ -5849,6 +5860,20 @@ void to_json(json& j, const JsonBasedUdfFunctionMetadata& p) { "JsonBasedUdfFunctionMetadata", "AggregationFunctionMetadata", "aggregateMetadata"); + to_json_key( + j, + "functionId", + p.functionId, + "JsonBasedUdfFunctionMetadata", + "SqlFunctionId", + "functionId"); + to_json_key( + j, + "version", + p.version, + "JsonBasedUdfFunctionMetadata", + "String", + "version"); } void from_json(const json& j, JsonBasedUdfFunctionMetadata& p) { @@ -5901,6 +5926,20 @@ void from_json(const json& j, JsonBasedUdfFunctionMetadata& p) { "JsonBasedUdfFunctionMetadata", "AggregationFunctionMetadata", "aggregateMetadata"); + from_json_key( + j, + "functionId", + p.functionId, + "JsonBasedUdfFunctionMetadata", + "SqlFunctionId", + "functionId"); + from_json_key( + j, + "version", + p.version, + "JsonBasedUdfFunctionMetadata", + "String", + "version"); } } // namespace facebook::presto::protocol /* @@ -8261,6 +8300,52 @@ void from_json(const json& j, RemoteTransactionHandle& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +RestFunctionHandle::RestFunctionHandle() noexcept { + _type = "rest"; +} + +void to_json(json& j, const RestFunctionHandle& p) { + j = json::object(); + j["@type"] = "rest"; + to_json_key( + j, + "functionId", + p.functionId, + "RestFunctionHandle", + "SqlFunctionId", + "functionId"); + to_json_key( + j, "version", p.version, "RestFunctionHandle", "String", "version"); + to_json_key( + j, + "signature", + p.signature, + "RestFunctionHandle", + "Signature", + "signature"); +} + +void from_json(const json& j, RestFunctionHandle& p) { + p._type = j["@type"]; + from_json_key( + j, + "functionId", + p.functionId, + "RestFunctionHandle", + "SqlFunctionId", + "functionId"); + from_json_key( + j, "version", p.version, "RestFunctionHandle", "String", "version"); + from_json_key( + j, + "signature", + p.signature, + "RestFunctionHandle", + "Signature", + "signature"); +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { RowNumberNode::RowNumberNode() noexcept { _type = "com.facebook.presto.sql.planner.plan.RowNumberNode"; } diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index 902c0c24ce5a0..665761c381e36 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -67,21 +67,21 @@ extern const char* const PRESTO_ABORT_TASK_URL_PARAM; class Exception : public std::runtime_error { public: explicit Exception(const std::string& message) - : std::runtime_error(message){}; + : std::runtime_error(message) {}; }; class TypeError : public Exception { public: - explicit TypeError(const std::string& message) : Exception(message){}; + explicit TypeError(const std::string& message) : Exception(message) {}; }; class OutOfRange : public Exception { public: - explicit OutOfRange(const std::string& message) : Exception(message){}; + explicit OutOfRange(const std::string& message) : Exception(message) {}; }; class ParseError : public Exception { public: - explicit ParseError(const std::string& message) : Exception(message){}; + explicit ParseError(const std::string& message) : Exception(message) {}; }; using String = std::string; @@ -1508,6 +1508,8 @@ struct JsonBasedUdfFunctionMetadata { String schema = {}; RoutineCharacteristics routineCharacteristics = {}; std::shared_ptr aggregateMetadata = {}; + std::shared_ptr functionId = {}; + std::shared_ptr version = {}; }; void to_json(json& j, const JsonBasedUdfFunctionMetadata& p); void from_json(const json& j, JsonBasedUdfFunctionMetadata& p); @@ -1941,6 +1943,17 @@ void to_json(json& j, const RemoteTransactionHandle& p); void from_json(const json& j, RemoteTransactionHandle& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +struct RestFunctionHandle : public FunctionHandle { + SqlFunctionId functionId = {}; + String version = {}; + Signature signature = {}; + + RestFunctionHandle() noexcept; +}; +void to_json(json& j, const RestFunctionHandle& p); +void from_json(const json& j, RestFunctionHandle& p); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct RowNumberNode : public PlanNode { std::shared_ptr source = {}; List partitionBy = {}; diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml index ae0222da2e67a..60a8826850f3c 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.yml @@ -175,7 +175,7 @@ AbstractClasses: subclasses: - { name: BuiltInFunctionHandle, key: $static } - { name: SqlFunctionHandle, key: json_file } - + - { name: RestFunctionHandle, key: rest } JavaClasses: - presto-spi/src/main/java/com/facebook/presto/spi/ErrorCause.java @@ -191,6 +191,7 @@ JavaClasses: - presto-main/src/main/java/com/facebook/presto/execution/buffer/BufferState.java - presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java - presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionHandle.java + - presto-spi/src/main/java/com/facebook/presto/spi/function/RestFunctionHandle.java - presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaRequirement.java - presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaScope.java - presto-spi/src/main/java/com/facebook/presto/spi/relation/CallExpression.java diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java index 024ba43035bd0..bf949e40da0d7 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java @@ -35,12 +35,19 @@ @ThriftStruct public final class Signature { + @JsonProperty("name") private final QualifiedObjectName name; + @JsonProperty("kind") private final FunctionKind kind; + @JsonProperty("typeVariableConstraints") private final List typeVariableConstraints; + @JsonProperty("longVariableConstraints") private final List longVariableConstraints; + @JsonProperty("returnType") private final TypeSignature returnType; + @JsonProperty("argumentTypes") private final List argumentTypes; + @JsonProperty("variableArity") private final boolean variableArity; @ThriftConstructor