diff --git a/velox/functions/remote/CMakeLists.txt b/velox/functions/remote/CMakeLists.txt index ccc8a2c5ec483..c5f32ca662cb4 100644 --- a/velox/functions/remote/CMakeLists.txt +++ b/velox/functions/remote/CMakeLists.txt @@ -12,6 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +if(NOT DEFINED PROXYGEN_LIBRARIES) + find_package(Sodium REQUIRED) + + find_library(PROXYGEN proxygen) + find_library(PROXYGEN_HTTP_SERVER proxygenhttpserver) + find_library(FIZZ fizz) + find_library(WANGLE wangle) + + if(NOT PROXYGEN + OR NOT PROXYGEN_HTTP_SERVER + OR NOT FIZZ + OR NOT WANGLE) + message( + FATAL_ERROR + "One or more proxygen libraries were not found. Please ensure proxygen, proxygenhttpserver, fizz, and wangle are installed." + ) + endif() + + set(PROXYGEN_LIBRARIES ${PROXYGEN_HTTP_SERVER} ${PROXYGEN} ${WANGLE} ${FIZZ}) +endif() + add_subdirectory(if) add_subdirectory(client) add_subdirectory(server) diff --git a/velox/functions/remote/client/CMakeLists.txt b/velox/functions/remote/client/CMakeLists.txt index 56663a29d04b8..28bdc31fb5351 100644 --- a/velox/functions/remote/client/CMakeLists.txt +++ b/velox/functions/remote/client/CMakeLists.txt @@ -16,10 +16,19 @@ velox_add_library(velox_functions_remote_thrift_client ThriftClient.cpp) velox_link_libraries(velox_functions_remote_thrift_client PUBLIC remote_function_thrift FBThrift::thriftcpp2) +velox_add_library(velox_functions_remote_rest_client RestClient.cpp) +velox_link_libraries(velox_functions_remote_rest_client ${PROXYGEN_LIBRARIES} + velox_exec Folly::folly) + velox_add_library(velox_functions_remote Remote.cpp) velox_link_libraries( velox_functions_remote - PUBLIC velox_expression + PUBLIC velox_functions_remote_rest_client + velox_expression + velox_memory + velox_exec + velox_vector + velox_presto_serializer velox_functions_remote_thrift_client velox_functions_remote_get_serde velox_type_fbhive diff --git a/velox/functions/remote/client/Remote.cpp b/velox/functions/remote/client/Remote.cpp index 8458b84baaef2..db68fceb91fb5 100644 --- a/velox/functions/remote/client/Remote.cpp +++ b/velox/functions/remote/client/Remote.cpp @@ -17,14 +17,25 @@ #include "velox/functions/remote/client/Remote.h" #include +#include "velox/common/memory/ByteStream.h" #include "velox/expression/Expr.h" #include "velox/expression/VectorFunction.h" +#include "velox/functions/remote/client/RestClient.h" #include "velox/functions/remote/client/ThriftClient.h" #include "velox/functions/remote/if/GetSerde.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionServiceAsyncClient.h" +#include "velox/serializers/PrestoSerializer.h" #include "velox/type/fbhive/HiveTypeSerializer.h" +#include "velox/vector/FlatVector.h" #include "velox/vector/VectorStream.h" +#include +#include +#include +#include + +using namespace folly; +using namespace proxygen; namespace facebook::velox::functions { namespace { @@ -33,17 +44,43 @@ std::string serializeType(const TypePtr& type) { return type::fbhive::HiveTypeSerializer::serialize(type); } +std::string extractFunctionName(const std::string& input) { + size_t lastDot = input.find_last_of('.'); + if (lastDot != std::string::npos) { + return input.substr(lastDot + 1); + } + return input; +} + +std::string urlEncode(const std::string& value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + for (char c : value) { + if (isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '~') { + escaped << c; + } else { + escaped << '%' << std::setw(2) << int(static_cast(c)); + } + } + return escaped.str(); +} + class RemoteFunction : public exec::VectorFunction { public: RemoteFunction( const std::string& functionName, const std::vector& inputArgs, const RemoteVectorFunctionMetadata& metadata) - : functionName_(functionName), - location_(metadata.location), - thriftClient_(getThriftClient(location_, &eventBase_)), - serdeFormat_(metadata.serdeFormat), - serde_(getSerde(serdeFormat_)) { + : functionName_(functionName), metadata_(metadata) { + if (metadata.location.type() == typeid(SocketAddress)) { + location_ = boost::get(metadata.location); + thriftClient_ = getThriftClient(location_, &eventBase_); + } else if (metadata.location.type() == typeid(URL)) { + url_ = boost::get(metadata.location); + } + std::vector types; types.reserve(inputArgs.size()); serializedInputTypes_.reserve(inputArgs.size()); @@ -62,7 +99,11 @@ class RemoteFunction : public exec::VectorFunction { exec::EvalCtx& context, VectorPtr& result) const override { try { - applyRemote(rows, args, outputType, context, result); + if ((metadata_.location.type() == typeid(SocketAddress))) { + applyRemote(rows, args, outputType, context, result); + } else if (metadata_.location.type() == typeid(URL)) { + applyRestRemote(rows, args, outputType, context, result); + } } catch (const VeloxRuntimeError&) { throw; } catch (const std::exception&) { @@ -71,6 +112,55 @@ class RemoteFunction : public exec::VectorFunction { } private: + void applyRestRemote( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const { + try { + std::string fullUrl = fmt::format( + "{}/v1/functions/{}/{}/{}/{}", + url_.getUrl(), + metadata_.schema.value_or("default"), + extractFunctionName(functionName_), + urlEncode(metadata_.functionId.value_or("default_function_id")), + metadata_.version.value_or("1")); + + // Serialize the input data + serializer::presto::PrestoVectorSerde serde; + serializer::presto::PrestoVectorSerde::PrestoOptions options; + + auto remoteRowVector = std::make_shared( + context.pool(), + remoteInputType_, + BufferPtr{}, + rows.end(), + std::move(args)); + + // Serialize the RowVector into an IOBuf (binary format) + IOBuf payload = rowVectorToIOBuf( + remoteRowVector, rows.end(), *context.pool(), &serde); + + // Send the serialized data to the remote function via RestClient + RestClient restClient(fullUrl); + std::unique_ptr responseBody; + restClient.invoke_function( + std::make_unique(std::move(payload)), (responseBody)); + + auto outputRowVector = IOBufToRowVector( + *responseBody, ROW({outputType}), *context.pool(), &serde); + result = outputRowVector->childAt(0); + + } catch (const std::exception& e) { + // Catch and handle any exceptions thrown during the process + VELOX_FAIL( + "Error while executing remote function '{}': {}", + functionName_, + e.what()); + } + } + void applyRemote( const SelectivityVector& rows, std::vector& args, @@ -97,11 +187,14 @@ class RemoteFunction : public exec::VectorFunction { auto requestInputs = request.inputs_ref(); requestInputs->rowCount_ref() = remoteRowVector->size(); - requestInputs->pageFormat_ref() = serdeFormat_; + requestInputs->pageFormat_ref() = metadata_.serdeFormat; // TODO: serialize only active rows. requestInputs->payload_ref() = rowVectorToIOBuf( - remoteRowVector, rows.end(), *context.pool(), serde_.get()); + remoteRowVector, + rows.end(), + *context.pool(), + getSerde(metadata_.serdeFormat).get()); try { thriftClient_->sync_invokeFunction(remoteResponse, request); @@ -117,12 +210,15 @@ class RemoteFunction : public exec::VectorFunction { remoteResponse.get_result().get_payload(), ROW({outputType}), *context.pool(), - serde_.get()); + getSerde(metadata_.serdeFormat).get()); result = outputRowVector->childAt(0); if (auto errorPayload = remoteResponse.get_result().errorPayload()) { auto errorsRowVector = IOBufToRowVector( - *errorPayload, ROW({VARCHAR()}), *context.pool(), serde_.get()); + *errorPayload, + ROW({VARCHAR()}), + *context.pool(), + getSerde(metadata_.serdeFormat).get()); auto errorsVector = errorsRowVector->childAt(0)->asFlatVector(); VELOX_CHECK(errorsVector, "Should be convertible to flat vector"); @@ -142,16 +238,14 @@ class RemoteFunction : public exec::VectorFunction { } const std::string functionName_; - folly::SocketAddress location_; - folly::EventBase eventBase_; + EventBase eventBase_; std::unique_ptr thriftClient_; - remote::PageFormat serdeFormat_; - std::unique_ptr serde_; - - // Structures we construct once to cache: + SocketAddress location_; + URL url_; RowTypePtr remoteInputType_; std::vector serializedInputTypes_; + const RemoteVectorFunctionMetadata metadata_; }; std::shared_ptr createRemoteFunction( @@ -169,7 +263,7 @@ void registerRemoteFunction( std::vector signatures, const RemoteVectorFunctionMetadata& metadata, bool overwrite) { - exec::registerStatefulVectorFunction( + registerStatefulVectorFunction( name, signatures, std::bind( diff --git a/velox/functions/remote/client/Remote.h b/velox/functions/remote/client/Remote.h index a6a1e773dc812..88b5544c172be 100644 --- a/velox/functions/remote/client/Remote.h +++ b/velox/functions/remote/client/Remote.h @@ -16,20 +16,38 @@ #pragma once +#include #include +#include #include "velox/expression/VectorFunction.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h" namespace facebook::velox::functions { struct RemoteVectorFunctionMetadata : public exec::VectorFunctionMetadata { - /// Network address of the servr to communicate with. Note that this can hold - /// a network location (ip/port pair) or a unix domain socket path (see + /// URL of the HTTP/REST server for remote function. + /// Or Network address of the servr to communicate with. Note that this can + /// hold a network location (ip/port pair) or a unix domain socket path (see /// SocketAddress::makeFromPath()). - folly::SocketAddress location; + boost::variant location; - /// The serialization format to be used + /// The serialization format to be used when sending data to the remote. remote::PageFormat serdeFormat{remote::PageFormat::PRESTO_PAGE}; + + /// Optional schema defining the structure of the data or input/output types + /// involved in the remote function. This may include details such as column + /// names and data types. + std::optional schema; + + /// Optional identifier for the specific remote function to be invoked. + /// This can be useful when the same server hosts multiple functions, + /// and the client needs to specify which function to call. + std::optional functionId; + + /// Optional version information to be used when calling the remote function. + /// This can help in ensuring compatibility with a particular version of the + /// function if multiple versions are available on the server. + std::optional version; }; /// Registers a new remote function. It will use the meatadata defined in @@ -38,8 +56,8 @@ struct RemoteVectorFunctionMetadata : public exec::VectorFunctionMetadata { // /// Remote functions are registered as regular statufull functions (using the /// same internal catalog), and hence conflict if there already exists a -/// (non-remote) function registered with the same name. The `overwrite` flag -/// controls whether to overwrite in these cases. +/// (non-remote) function registered with the same name. The `overwrite` +/// flagwrite controls whether to overwrite in these cases. void registerRemoteFunction( const std::string& name, std::vector signatures, diff --git a/velox/functions/remote/client/RestClient.cpp b/velox/functions/remote/client/RestClient.cpp new file mode 100644 index 0000000000000..7835dbd9c9d9b --- /dev/null +++ b/velox/functions/remote/client/RestClient.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RestClient.h" +#include +#include + +namespace facebook::velox::functions { + +// RestClient Implementation +RestClient::RestClient(const std::string& url) : url_(URL(url)) { + httpClient_ = std::make_shared(url_); +} + +void RestClient::invoke_function( + std::unique_ptr requestBody, + std::unique_ptr& responseBody) const { + httpClient_->send(std::move(requestBody)); + responseBody = httpClient_->getResponseBody(); +}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/RestClient.h b/velox/functions/remote/client/RestClient.h new file mode 100644 index 0000000000000..79772eb5d511d --- /dev/null +++ b/velox/functions/remote/client/RestClient.h @@ -0,0 +1,133 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include "velox/functions/remote/client/RestClient.h" + +using namespace proxygen; +using namespace folly; + +namespace facebook::velox::functions { + +class HttpClient : public HTTPConnector::Callback, + public HTTPTransactionHandler { + public: + HttpClient(const URL& url) : url_(url) {} + + void send(std::unique_ptr requestBody) { + requestBody_ = std::move(requestBody); + connector_ = std::make_unique( + this, WheelTimerInstance(std::chrono::milliseconds(1000))); + connector_->connect( + &evb_, + SocketAddress(url_.getHost(), url_.getPort(), true), + std::chrono::milliseconds(10000)); + evb_.loop(); + } + + std::unique_ptr getResponseBody() { + return std::move(responseBody_); + } + + private: + URL url_; + EventBase evb_; + std::unique_ptr connector_; + std::shared_ptr session_; + std::unique_ptr requestBody_; + std::unique_ptr responseBody_; + + void connectSuccess(HTTPUpstreamSession* session) noexcept override { + session_ = std::shared_ptr( + session, [](HTTPUpstreamSession* s) { + // No-op deleter, managed by Proxygen + }); + sendRequest(); + } + + void connectError(const AsyncSocketException& ex) noexcept override { + LOG(ERROR) << "Failed to connect: " << ex.what(); + evb_.terminateLoopSoon(); + } + + void sendRequest() { + auto txn = session_->newTransaction(this); + HTTPMessage req; + req.setMethod(HTTPMethod::POST); + req.setURL(url_.getUrl()); + + req.getHeaders().add( + "Content-Length", + std::to_string(requestBody_->computeChainDataLength())); + req.getHeaders().add("Content-Type", "application/X-presto-pages"); + req.getHeaders().add("Accept", "application/X-presto-pages"); + + txn->sendHeaders(req); + txn->sendBody(std::move(requestBody_)); + txn->sendEOM(); + } + + void setTransaction(HTTPTransaction*) noexcept override {} + + void detachTransaction() noexcept override { + session_.reset(); + evb_.terminateLoopSoon(); + } + + void onHeadersComplete(std::unique_ptr msg) noexcept override { + LOG(INFO) << "Received headers"; + } + + void onBody(std::unique_ptr chain) noexcept override { + responseBody_ = std::move(chain); + } + + void onEOM() noexcept override { + LOG(INFO) << "Transaction complete"; + session_->drain(); + } + + void onError(const HTTPException& error) noexcept override { + LOG(ERROR) << "Error: " << error.what(); + } + + void onUpgrade(UpgradeProtocol) noexcept override {} + void onTrailers(std::unique_ptr) noexcept override {} + void onEgressPaused() noexcept override {} + void onEgressResumed() noexcept override {} +}; + +class RestClient { + public: + RestClient(const std::string& url); + + void invoke_function( + std::unique_ptr request, + std::unique_ptr& response) const; + + private: + URL url_; + std::shared_ptr httpClient_; +}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/tests/CMakeLists.txt b/velox/functions/remote/client/tests/CMakeLists.txt index 1659ad9d7e5a3..38d0b25dbbd73 100644 --- a/velox/functions/remote/client/tests/CMakeLists.txt +++ b/velox/functions/remote/client/tests/CMakeLists.txt @@ -27,3 +27,20 @@ target_link_libraries( GTest::gmock GTest::gtest GTest::gtest_main) + +add_executable(velox_functions_remote_client_rest_test + RemoteFunctionRestTest.cpp) + +add_test(velox_functions_remote_client_rest_test + velox_functions_remote_client_rest_test) + +target_link_libraries( + velox_functions_remote_client_rest_test + velox_functions_remote_server_rest + velox_functions_remote + velox_function_registry + velox_functions_test_lib + velox_exec_test_lib + GTest::gmock + GTest::gtest + GTest::gtest_main) diff --git a/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp new file mode 100644 index 0000000000000..96c71fdd550a5 --- /dev/null +++ b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/Registerer.h" +#include "velox/functions/lib/CheckedArithmetic.h" +#include "velox/functions/prestosql/Arithmetic.h" +#include "velox/functions/prestosql/StringFunctions.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/remote/client/Remote.h" +#include "velox/functions/remote/server/RemoteFunctionRestService.h" + +using ::facebook::velox::test::assertEqualVectors; + +namespace facebook::velox::functions { +namespace { + +class RemoteFunctionRestTest + : public test::FunctionBaseTest, + public testing::WithParamInterface { + public: + void SetUp() override { + initializeServer(); + registerRemoteFunctions(); + } + + // Registers a few remote functions to be used in this test. + void registerRemoteFunctions() const { + RemoteVectorFunctionMetadata metadata; + metadata.serdeFormat = GetParam(); + metadata.location = location_; + + auto absSignature = {exec::FunctionSignatureBuilder() + .returnType("integer") + .argumentType("integer") + .build()}; + registerRemoteFunction("remote_abs", absSignature, metadata); + + auto plusSignatures = {exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("bigint") + .argumentType("bigint") + .build()}; + registerRemoteFunction("remote_plus", plusSignatures, metadata); + + RemoteVectorFunctionMetadata wrongMetadata = metadata; + wrongMetadata.location = folly::SocketAddress(); // empty address. + registerRemoteFunction("remote_wrong_port", plusSignatures, wrongMetadata); + + auto divSignatures = {exec::FunctionSignatureBuilder() + .returnType("double") + .argumentType("double") + .argumentType("double") + .build()}; + registerRemoteFunction("remote_divide", divSignatures, metadata); + + auto substrSignatures = {exec::FunctionSignatureBuilder() + .returnType("varchar") + .argumentType("varchar") + .argumentType("integer") + .build()}; + registerRemoteFunction("remote_substr", substrSignatures, metadata); + + // Registers the actual function under a different prefix. This is only + // needed for tests since the http service runs in the same process. + registerFunction( + {remotePrefix_ + ".remote_abs"}); + registerFunction( + {remotePrefix_ + ".remote_plus"}); + registerFunction( + {remotePrefix_ + ".remote_divide"}); + registerFunction( + {remotePrefix_ + ".remote_substr"}); + } + + void initializeServer() { + HTTPServerOptions options; + options.idleTimeout = std::chrono::milliseconds(6000); + options.handlerFactories = + RequestHandlerChain() + .addThen(remotePrefix_) + .build(); + options.h2cEnabled = true; + + std::vector IPs = { + {folly::SocketAddress(location_.getHost(), location_.getPort(), true), + HTTPServer::Protocol::HTTP}}; + + server_ = std::make_shared(std::move(options)); + server_->bind(IPs); + + thread_ = std::make_unique([&] { server_->start(); }); + + VELOX_CHECK(waitForRunning(), "Unable to initialize HTTP server."); + LOG(INFO) << "HTTP server is up and running in local port " + << location_.getUrl(); + } + + ~RemoteFunctionRestTest() override { + server_->stop(); + thread_->join(); + LOG(INFO) << "HTTP server stopped."; + } + + private: + bool waitForRunning() const { + for (size_t i = 0; i < 100; ++i) { + using boost::asio::ip::tcp; + boost::asio::io_context io_context; + + tcp::socket socket(io_context); + tcp::resolver resolver(io_context); + + try { + boost::asio::connect( + socket, + resolver.resolve( + location_.getHost(), std::to_string(location_.getPort()))); + return true; + } catch (std::exception& e) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + } + return false; + } + + std::shared_ptr server_; + std::unique_ptr thread_; + + URL location_{URL("http://127.0.0.1:83211")}; + const std::string remotePrefix_{"remote"}; +}; + +TEST_P(RemoteFunctionRestTest, absolute) { + auto inputVector = makeFlatVector({-10, -20}); + auto results = evaluate>( + "remote_abs(c0)", makeRowVector({inputVector})); + + auto expected = makeFlatVector({10, 20}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, simple) { + auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); + auto results = evaluate>( + "remote_plus(c0, c0)", makeRowVector({inputVector})); + + auto expected = makeFlatVector({2, 4, 6, 8, 10}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, string) { + auto inputVector = + makeFlatVector({"hello", "my", "remote", "world"}); + auto inputVector1 = makeFlatVector({2, 1, 3, 5}); + auto results = evaluate>( + "remote_substr(c0, c1)", makeRowVector({inputVector, inputVector1})); + + auto expected = makeFlatVector({"ello", "my", "mote", "d"}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, connectionError) { + auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); + auto func = [&]() { + evaluate>( + "remote_wrong_port(c0, c0)", makeRowVector({inputVector})); + }; + + // Check it throw and that the exception has the "connection refused" + // substring. + EXPECT_THROW(func(), VeloxRuntimeError); + try { + func(); + } catch (const VeloxRuntimeError& e) { + EXPECT_THAT(e.message(), testing::HasSubstr("Channel is !good()")); + } +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + RemoteFunctionRestTestFixture, + RemoteFunctionRestTest, + ::testing::Values(remote::PageFormat::PRESTO_PAGE)); + +} // namespace +} // namespace facebook::velox::functions + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/velox/functions/remote/server/CMakeLists.txt b/velox/functions/remote/server/CMakeLists.txt index ff2afa0fed6a8..e6dac7c977994 100644 --- a/velox/functions/remote/server/CMakeLists.txt +++ b/velox/functions/remote/server/CMakeLists.txt @@ -24,3 +24,19 @@ add_executable(velox_functions_remote_server_main RemoteFunctionServiceMain.cpp) target_link_libraries( velox_functions_remote_server_main velox_functions_remote_server velox_functions_prestosql) + +add_library(velox_functions_remote_server_rest RemoteFunctionRestService.cpp) +target_link_libraries( + velox_functions_remote_server_rest + ${PROXYGEN_LIBRARIES} + velox_type_fbhive + velox_memory + velox_functions_prestosql + velox_presto_serializer) + +add_executable(velox_functions_remote_server_rest_main + RemoteFunctionServiceRestMain.cpp) + +target_link_libraries( + velox_functions_remote_server_rest_main velox_functions_remote_server_rest + velox_functions_prestosql) diff --git a/velox/functions/remote/server/RemoteFunctionRestService.cpp b/velox/functions/remote/server/RemoteFunctionRestService.cpp new file mode 100644 index 0000000000000..77a9cda840c5e --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionRestService.cpp @@ -0,0 +1,209 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/remote/server/RemoteFunctionRestService.h" +#include +#include +#include + +#include "velox/expression/Expr.h" +#include "velox/type/fbhive/HiveTypeParser.h" +#include "velox/vector/VectorStream.h" + +namespace facebook::velox::functions { + +namespace { +struct InternalFunctionSignature { + std::vector argumentTypes; + std::string returnType; +}; + +// Initialize a map with function details +std::map internalFunctionSignatureMap = + { + {"remote_abs", {{"integer"}, "integer"}}, + {"remote_plus", {{"bigint", "bigint"}, "bigint"}}, + {"remote_divide", {{"double", "double"}, "double"}}, + {"remote_substr", {{"varchar", "integer"}, "varchar"}}, + // Add more functions here as needed +}; + +TypePtr deserializeType(const std::string& input) { + // Use hive type parser/serializer. + return type::fbhive::HiveTypeParser().parse(input); +} + +RowTypePtr deserializeArgTypes(const std::vector& argTypes) { + const size_t argCount = argTypes.size(); + + std::vector argumentTypes; + std::vector typeNames; + argumentTypes.reserve(argCount); + typeNames.reserve(argCount); + + for (size_t i = 0; i < argCount; ++i) { + argumentTypes.emplace_back(deserializeType(argTypes[i])); + typeNames.emplace_back(fmt::format("c{}", i)); + } + return ROW(std::move(typeNames), std::move(argumentTypes)); +} + +std::string getFunctionName( + const std::string& prefix, + const std::string& functionName) { + return prefix.empty() ? functionName + : fmt::format("{}.{}", prefix, functionName); +} +} // namespace + +std::vector getExpressions( + const RowTypePtr& inputType, + const TypePtr& returnType, + const std::string& functionName) { + std::vector inputs; + for (size_t i = 0; i < inputType->size(); ++i) { + inputs.push_back(std::make_shared( + inputType->childAt(i), inputType->nameOf(i))); + } + + return {std::make_shared( + returnType, std::move(inputs), functionName)}; +} + +void RestRequestHandler::onRequest( + std::unique_ptr headers) noexcept { + const std::string& path = headers->getURL(); + + // Split the path by '/' + std::vector pathComponents; + folly::split('/', path, pathComponents); + + // Check if the path has enough components + if (pathComponents.size() >= 5) { + // Extract the functionName from the path + // Assuming the functionName is the 5th component + functionName_ = pathComponents[6]; + } +} + +void RestRequestHandler::onEOM() noexcept { + try { + const auto& functionSignature = + internalFunctionSignatureMap.at(functionName_); + + auto inputType = deserializeArgTypes(functionSignature.argumentTypes); + auto returnType = deserializeType(functionSignature.returnType); + + serializer::presto::PrestoVectorSerde serde; + auto inputVector = IOBufToRowVector(*body_, inputType, *pool_, &serde); + + const vector_size_t numRows = inputVector->size(); + SelectivityVector rows{numRows}; + + // Expression boilerplate. + auto queryCtx = core::QueryCtx::create(); + core::ExecCtx execCtx{pool_.get(), queryCtx.get()}; + exec::ExprSet exprSet{ + getExpressions( + inputType, + returnType, + getFunctionName(functionPrefix_, functionName_)), + &execCtx}; + exec::EvalCtx evalCtx(&execCtx, &exprSet, inputVector.get()); + + std::vector expressionResult; + exprSet.eval(rows, evalCtx, expressionResult); + + // Create output vector. + auto outputRowVector = std::make_shared( + pool_.get(), ROW({returnType}), BufferPtr(), numRows, expressionResult); + + auto payload = + rowVectorToIOBuf(outputRowVector, rows.end(), *pool_, &serde); + + ResponseBuilder(downstream_) + .status(200, "OK") + .body(std::make_unique(payload)) + .sendWithEOM(); + + } catch (const std::exception& ex) { + LOG(ERROR) << ex.what(); + ResponseBuilder(downstream_) + .status(500, "Internal Server Error") + .body(folly::IOBuf::copyBuffer(ex.what())) + .sendWithEOM(); + } +} + +void RestRequestHandler::onBody(std::unique_ptr body) noexcept { + if (body) { + body_ = std::move(body); + } +} + +void RestRequestHandler::onUpgrade(UpgradeProtocol /*protocol*/) noexcept { + // handler doesn't support upgrades +} + +void RestRequestHandler::requestComplete() noexcept { + delete this; +} + +void RestRequestHandler::onError(ProxygenError /*err*/) noexcept { + delete this; +} + +// ErrorHandler +ErrorHandler::ErrorHandler(int statusCode, std::string message) + : statusCode_(statusCode), message_(std::move(message)) {} + +void ErrorHandler::onRequest(std::unique_ptr) noexcept { + ResponseBuilder(downstream_) + .status(statusCode_, "Error") + .body(std::move(message_)) + .sendWithEOM(); +} + +void ErrorHandler::onEOM() noexcept {} + +void ErrorHandler::onBody(std::unique_ptr body) noexcept {} + +void ErrorHandler::onUpgrade(UpgradeProtocol protocol) noexcept { + // handler doesn't support upgrades +} + +void ErrorHandler::requestComplete() noexcept { + delete this; +} + +void ErrorHandler::onError(ProxygenError err) noexcept { + delete this; +} + +// RestRequestHandlerFactory +void RestRequestHandlerFactory::onServerStart(folly::EventBase* evb) noexcept {} + +void RestRequestHandlerFactory::onServerStop() noexcept {} + +RequestHandler* RestRequestHandlerFactory::onRequest( + proxygen::RequestHandler*, + proxygen::HTTPMessage* msg) noexcept { + if (msg->getMethod() != HTTPMethod::POST) { + return new ErrorHandler(405, "Only POST method is allowed"); + } + return new RestRequestHandler(functionPrefix_); +} +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionRestService.h b/velox/functions/remote/server/RemoteFunctionRestService.h new file mode 100644 index 0000000000000..358965f753c0d --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionRestService.h @@ -0,0 +1,71 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include "velox/common/memory/Memory.h" + +using namespace proxygen; + +namespace facebook::velox::functions { +class ErrorHandler : public RequestHandler { + public: + explicit ErrorHandler(int statusCode, std::string message); + void onRequest(std::unique_ptr headers) noexcept override; + void onBody(std::unique_ptr) noexcept override; + void onEOM() noexcept override; + void onUpgrade(UpgradeProtocol protocol) noexcept override; + void requestComplete() noexcept override; + void onError(ProxygenError err) noexcept override; + + private: + int statusCode_; + std::string message_; +}; + +class RestRequestHandler : public RequestHandler { + public: + explicit RestRequestHandler(const std::string& functionPrefix = "") + : functionPrefix_(functionPrefix) {} + void onRequest(std::unique_ptr headers) noexcept override; + void onBody(std::unique_ptr body) noexcept override; + void onEOM() noexcept override; + void onUpgrade(UpgradeProtocol protocol) noexcept override; + void requestComplete() noexcept override; + void onError(ProxygenError err) noexcept override; + + private: + std::unique_ptr body_; + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; + const std::string functionPrefix_; + std::string functionName_; +}; + +class RestRequestHandlerFactory : public RequestHandlerFactory { + public: + explicit RestRequestHandlerFactory(const std::string& functionPrefix = "") + : functionPrefix_(functionPrefix) {} + void onServerStart(folly::EventBase* evb) noexcept override; + void onServerStop() noexcept override; + RequestHandler* onRequest(RequestHandler*, HTTPMessage* msg) noexcept + override; + + private: + const std::string functionPrefix_; +}; +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionServiceMain.cpp b/velox/functions/remote/server/RemoteFunctionServiceMain.cpp index c92ab9231d114..92ff2791bb1f8 100644 --- a/velox/functions/remote/server/RemoteFunctionServiceMain.cpp +++ b/velox/functions/remote/server/RemoteFunctionServiceMain.cpp @@ -18,6 +18,7 @@ #include #include #include +#include "velox/functions/prestosql/StringFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/remote/server/RemoteFunctionService.h" @@ -36,7 +37,7 @@ DEFINE_string( DEFINE_string( function_prefix, - "json.test_schema.", + "remote.schema.", "Prefix to be added to the functions being registered"); using namespace ::facebook::velox; @@ -46,11 +47,14 @@ int main(int argc, char* argv[]) { folly::Init init{&argc, &argv, false}; FLAGS_logtostderr = true; + memory::initializeMemoryManager({}); + // Always registers all Presto functions and make them available under a // certain prefix/namespace. LOG(INFO) << "Registering Presto functions"; functions::prestosql::registerAllScalarFunctions(FLAGS_function_prefix); + std::remove(FLAGS_uds_path.c_str()); folly::SocketAddress location{ folly::SocketAddress::makeFromPath(FLAGS_uds_path)}; diff --git a/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp b/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp new file mode 100644 index 0000000000000..4444c8c2d5645 --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "velox/common/memory/Memory.h" + +#include "velox/functions/Registerer.h" +#include "velox/functions/prestosql/Arithmetic.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" +#include "velox/functions/remote/server/RemoteFunctionRestService.h" + +DEFINE_string( + service_host, + "127.0.0.1", + "Prefix to be added to the functions being registered"); + +DEFINE_int32( + service_port, + 8321, + "Prefix to be added to the functions being registered"); + +DEFINE_string( + function_prefix, + "remote.schema.", + "Prefix to be added to the functions being registered"); + +using namespace ::facebook::velox; + +int main(int argc, char* argv[]) { + folly::Init init(&argc, &argv); + FLAGS_logtostderr = true; + memory::initializeMemoryManager({}); + + // A remote function service should handle the function execution by its own. + // But we use Velox framework for quick prototype here + functions::prestosql::registerAllScalarFunctions(FLAGS_function_prefix); + // registerFunction( + // {"remote_plus"}); + // End of function registration + + LOG(INFO) << "Start HTTP Server at " << "http://" << FLAGS_service_host << ":" + << FLAGS_service_port; + + HTTPServerOptions options; + options.idleTimeout = std::chrono::milliseconds(60000); + options.handlerFactories = + RequestHandlerChain() + .addThen() + .build(); + options.h2cEnabled = true; + + std::vector IPs = { + {folly::SocketAddress(FLAGS_service_host, FLAGS_service_port, true), + HTTPServer::Protocol::HTTP}}; + + HTTPServer server(std::move(options)); + server.bind(IPs); + + std::thread t([&]() { server.start(); }); + + t.join(); + return 0; +}