Skip to content

Commit

Permalink
Add support for REST based remote functions.
Browse files Browse the repository at this point in the history
Co-authored-by: Wills Feng <[email protected]>
  • Loading branch information
Joe-Abraham and wills-feng committed Feb 3, 2025
1 parent a7a7f80 commit 9929892
Show file tree
Hide file tree
Showing 11 changed files with 183 additions and 5 deletions.
2 changes: 1 addition & 1 deletion presto-native-execution/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ option(PRESTO_ENABLE_ABFS "Build ABFS support" OFF)
option(PRESTO_ENABLE_PARQUET "Enable Parquet support" OFF)

# Forwards user input to VELOX_ENABLE_REMOTE_FUNCTIONS.
option(PRESTO_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" OFF)
option(PRESTO_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" ON)

option(PRESTO_ENABLE_TESTING "Enable tests" ON)

Expand Down
5 changes: 3 additions & 2 deletions presto-native-execution/presto_cpp/main/PrestoServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1342,10 +1342,11 @@ void PrestoServer::registerRemoteFunctions() {
} else {
VELOX_FAIL(
"To register remote functions using a json file path you need to "
"specify the remote server location using '{}', '{}' or '{}'.",
"specify the remote server location using '{}', '{}' or '{}' or {}.",
SystemConfig::kRemoteFunctionServerThriftAddress,
SystemConfig::kRemoteFunctionServerThriftPort,
SystemConfig::kRemoteFunctionServerThriftUdsPath);
SystemConfig::kRemoteFunctionServerThriftUdsPath,
SystemConfig::kRemoteFunctionServerRestURL);
}
}
#endif
Expand Down
4 changes: 4 additions & 0 deletions presto-native-execution/presto_cpp/main/common/Configs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,10 @@ std::string SystemConfig::remoteFunctionServerSerde() const {
return optionalProperty(kRemoteFunctionServerSerde).value();
}

std::string SystemConfig::remoteFunctionRestUrl() const {
return optionalProperty(kRemoteFunctionServerRestURL).value();
}

int32_t SystemConfig::maxDriversPerTask() const {
return optionalProperty<int32_t>(kMaxDriversPerTask).value();
}
Expand Down
6 changes: 6 additions & 0 deletions presto-native-execution/presto_cpp/main/common/Configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,10 @@ class SystemConfig : public ConfigBase {
static constexpr std::string_view kRemoteFunctionServerThriftUdsPath{
"remote-function-server.thrift.uds-path"};

/// HTTP URL used by the remote function rest server.
static constexpr std::string_view kRemoteFunctionServerRestURL{
"remote-function-server.rest.url"};

/// Path where json files containing signatures for remote functions can be
/// found.
static constexpr std::string_view
Expand Down Expand Up @@ -714,6 +718,8 @@ class SystemConfig : public ConfigBase {

std::string remoteFunctionServerSerde() const;

std::string remoteFunctionRestUrl() const;

int32_t maxDriversPerTask() const;

folly::Optional<int32_t> taskWriterCount() const;
Expand Down
12 changes: 10 additions & 2 deletions presto-native-execution/presto_cpp/main/types/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,16 @@ add_library(
add_dependencies(presto_types presto_operators presto_type_converter velox_type
velox_type_fbhive)

target_link_libraries(presto_types presto_type_converter velox_type_fbhive
velox_hive_partition_function velox_tpch_gen velox_functions_json)
target_link_libraries(
presto_types presto_type_converter velox_type_fbhive
velox_hive_partition_function velox_tpch_gen velox_functions_json)

if(PRESTO_ENABLE_REMOTE_FUNCTIONS)
add_dependencies(presto_types velox_expression presto_server_remote_function
velox_functions_remote)
target_link_libraries(presto_types presto_server_remote_function
velox_functions_remote)
endif()

set_property(TARGET presto_types PROPERTY JOB_POOL_LINK presto_link_job_pool)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,16 @@
#include "velox/vector/ComplexVector.h"
#include "velox/vector/ConstantVector.h"
#include "velox/vector/FlatVector.h"
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
#include "presto_cpp/main/JsonSignatureParser.h"
#include "velox/expression/FunctionSignature.h"
#include "velox/functions/remote/client/Remote.h"
#endif

using namespace facebook::velox::core;
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
using facebook::velox::functions::remote::PageFormat;
#endif
using facebook::velox::TypeKind;

namespace facebook::presto {
Expand Down Expand Up @@ -127,6 +135,61 @@ std::string getFunctionName(const protocol::SqlFunctionId& functionId) {

} // namespace

#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
TypedExprPtr VeloxExprConverter::registerRestRemoteFunction(
const protocol::RestFunctionHandle& restFunctionHandle,
const std::vector<TypedExprPtr>& args,
const velox::TypePtr& returnType) const {
const auto* systemConfig = SystemConfig::instance();

velox::functions::RemoteVectorFunctionMetadata metadata;
const auto& serdeName = systemConfig->remoteFunctionServerSerde();
if (serdeName == "presto_page") {
metadata.serdeFormat = PageFormat::PRESTO_PAGE;
} else {
VELOX_FAIL(
"presto_page serde is expected by remote function server but got : '{}'",
serdeName);
}
metadata.location = systemConfig->remoteFunctionRestUrl();
metadata.functionId = restFunctionHandle.functionId;
metadata.version = restFunctionHandle.version;

const auto& prestoSignature = restFunctionHandle.signature;
velox::exec::FunctionSignatureBuilder signatureBuilder;

for (const auto& typeVar : prestoSignature.typeVariableConstraints) {
signatureBuilder.typeVariable(typeVar.name);
}

for (const auto& longVar : prestoSignature.longVariableConstraints) {
signatureBuilder.integerVariable(longVar.name);
}

signatureBuilder.returnType(prestoSignature.returnType);

for (const auto& argType : prestoSignature.argumentTypes) {
signatureBuilder.argumentType(argType);
}

if (prestoSignature.variableArity) {
signatureBuilder.variableArity();
}

auto signature = signatureBuilder.build();
std::vector<velox::exec::FunctionSignaturePtr> veloxSignatures = {signature};

velox::functions::registerRemoteFunction(
getFunctionName(restFunctionHandle.functionId),
veloxSignatures,
metadata,
false);

return std::make_shared<CallTypedExpr>(
returnType, args, getFunctionName(restFunctionHandle.functionId));
}
#endif

velox::variant VeloxExprConverter::getConstantValue(
const velox::TypePtr& type,
const protocol::Block& block) const {
Expand Down Expand Up @@ -504,10 +567,22 @@ TypedExprPtr VeloxExprConverter::toVeloxExpr(
pexpr.functionHandle)) {
auto args = toVeloxExpr(pexpr.arguments);
auto returnType = typeParser_->parse(pexpr.returnType);

return std::make_shared<CallTypedExpr>(
returnType, args, getFunctionName(sqlFunctionHandle->functionId));
}
#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
else if (
auto restFunctionHandle =
std::dynamic_pointer_cast<protocol::RestFunctionHandle>(
pexpr.functionHandle)) {
// Defer to our new helper function for restFunctionHandle.
auto args = toVeloxExpr(pexpr.arguments);
auto returnType = typeParser_->parse(pexpr.returnType);

return registerRestRemoteFunction(*restFunctionHandle, args, returnType);
}
#endif
VELOX_FAIL("Unsupported function handle: {}", pexpr.functionHandle->_type);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ class VeloxExprConverter {
std::optional<velox::core::TypedExprPtr> tryConvertDate(
const protocol::CallExpression& pexpr) const;

#ifdef PRESTO_ENABLE_REMOTE_FUNCTIONS
velox::core::TypedExprPtr registerRestRemoteFunction(
const protocol::RestFunctionHandle& restFunctionHandle,
const std::vector<velox::core::TypedExprPtr>& args,
const velox::TypePtr& returnType) const;
#endif

velox::memory::MemoryPool* const pool_;
TypeParser* const typeParser_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ void to_json(json& j, const std::shared_ptr<FunctionHandle>& p) {
j = *std::static_pointer_cast<SqlFunctionHandle>(p);
return;
}
if (type == "rest") {
j = *std::static_pointer_cast<RestFunctionHandle>(p);
return;
}

throw TypeError(type + " no abstract type FunctionHandle ");
}
Expand Down Expand Up @@ -149,6 +153,13 @@ void from_json(const json& j, std::shared_ptr<FunctionHandle>& p) {
p = std::static_pointer_cast<FunctionHandle>(k);
return;
}
if (type == "rest") {
std::shared_ptr<RestFunctionHandle> k =
std::make_shared<RestFunctionHandle>();
j.get_to(*k);
p = std::static_pointer_cast<FunctionHandle>(k);
return;
}

throw TypeError(type + " no abstract type FunctionHandle ");
}
Expand Down Expand Up @@ -8356,6 +8367,52 @@ void from_json(const json& j, RemoteTransactionHandle& p) {
}
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
RestFunctionHandle::RestFunctionHandle() noexcept {
_type = "rest";
}

void to_json(json& j, const RestFunctionHandle& p) {
j = json::object();
j["@type"] = "rest";
to_json_key(
j,
"functionId",
p.functionId,
"RestFunctionHandle",
"SqlFunctionId",
"functionId");
to_json_key(
j, "version", p.version, "RestFunctionHandle", "String", "version");
to_json_key(
j,
"signature",
p.signature,
"RestFunctionHandle",
"Signature",
"signature");
}

void from_json(const json& j, RestFunctionHandle& p) {
p._type = j["@type"];
from_json_key(
j,
"functionId",
p.functionId,
"RestFunctionHandle",
"SqlFunctionId",
"functionId");
from_json_key(
j, "version", p.version, "RestFunctionHandle", "String", "version");
from_json_key(
j,
"signature",
p.signature,
"RestFunctionHandle",
"Signature",
"signature");
}
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
RowNumberNode::RowNumberNode() noexcept {
_type = "com.facebook.presto.sql.planner.plan.RowNumberNode";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1947,6 +1947,17 @@ void to_json(json& j, const RemoteTransactionHandle& p);
void from_json(const json& j, RemoteTransactionHandle& p);
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
struct RestFunctionHandle : public FunctionHandle {
SqlFunctionId functionId = {};
String version = {};
Signature signature = {};

RestFunctionHandle() noexcept;
};
void to_json(json& j, const RestFunctionHandle& p);
void from_json(const json& j, RestFunctionHandle& p);
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
struct RowNumberNode : public PlanNode {
std::shared_ptr<PlanNode> source = {};
List<VariableReferenceExpression> partitionBy = {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ AbstractClasses:
- { name: BuiltInFunctionHandle, key: $static }
- { name: SqlFunctionHandle, key: native }
- { name: SqlFunctionHandle, key: json_file }
- { name: RestFunctionHandle, key: rest }


JavaClasses:
Expand All @@ -192,6 +193,7 @@ JavaClasses:
- presto-main/src/main/java/com/facebook/presto/execution/buffer/BufferState.java
- presto-main/src/main/java/com/facebook/presto/metadata/BuiltInFunctionHandle.java
- presto-spi/src/main/java/com/facebook/presto/spi/function/SqlFunctionHandle.java
- presto-spi/src/main/java/com/facebook/presto/spi/function/RestFunctionHandle.java
- presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaRequirement.java
- presto-hdfs-core/src/main/java/com/facebook/presto/hive/CacheQuotaScope.java
- presto-spi/src/main/java/com/facebook/presto/spi/relation/CallExpression.java
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,19 @@
@ThriftStruct
public final class Signature
{
@JsonProperty("name")
private final QualifiedObjectName name;
@JsonProperty("kind")
private final FunctionKind kind;
@JsonProperty("typeVariableConstraints")
private final List<TypeVariableConstraint> typeVariableConstraints;
@JsonProperty("longVariableConstraints")
private final List<LongVariableConstraint> longVariableConstraints;
@JsonProperty("returnType")
private final TypeSignature returnType;
@JsonProperty("argumentTypes")
private final List<TypeSignature> argumentTypes;
@JsonProperty("variableArity")
private final boolean variableArity;

@ThriftConstructor
Expand Down

0 comments on commit 9929892

Please sign in to comment.