diff --git a/CMakeLists.txt b/CMakeLists.txt index c41075394e495..69ab08b87a37c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -114,7 +114,7 @@ option(VELOX_ENABLE_ABFS "Build Abfs Connector" OFF) option(VELOX_ENABLE_HDFS "Build Hdfs Connector" OFF) option(VELOX_ENABLE_PARQUET "Enable Parquet support" OFF) option(VELOX_ENABLE_ARROW "Enable Arrow support" OFF) -option(VELOX_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" OFF) +option(VELOX_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" ON) option(VELOX_ENABLE_CCACHE "Use ccache if installed." ON) option(VELOX_BUILD_TEST_UTILS "Builds Velox test utilities" OFF) diff --git a/velox/functions/remote/CMakeLists.txt b/velox/functions/remote/CMakeLists.txt index c5f32ca662cb4..ccc8a2c5ec483 100644 --- a/velox/functions/remote/CMakeLists.txt +++ b/velox/functions/remote/CMakeLists.txt @@ -12,27 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -if(NOT DEFINED PROXYGEN_LIBRARIES) - find_package(Sodium REQUIRED) - - find_library(PROXYGEN proxygen) - find_library(PROXYGEN_HTTP_SERVER proxygenhttpserver) - find_library(FIZZ fizz) - find_library(WANGLE wangle) - - if(NOT PROXYGEN - OR NOT PROXYGEN_HTTP_SERVER - OR NOT FIZZ - OR NOT WANGLE) - message( - FATAL_ERROR - "One or more proxygen libraries were not found. Please ensure proxygen, proxygenhttpserver, fizz, and wangle are installed." - ) - endif() - - set(PROXYGEN_LIBRARIES ${PROXYGEN_HTTP_SERVER} ${PROXYGEN} ${WANGLE} ${FIZZ}) -endif() - add_subdirectory(if) add_subdirectory(client) add_subdirectory(server) diff --git a/velox/functions/remote/client/CMakeLists.txt b/velox/functions/remote/client/CMakeLists.txt index 28bdc31fb5351..8e2f5802b8ea2 100644 --- a/velox/functions/remote/client/CMakeLists.txt +++ b/velox/functions/remote/client/CMakeLists.txt @@ -16,20 +16,20 @@ velox_add_library(velox_functions_remote_thrift_client ThriftClient.cpp) velox_link_libraries(velox_functions_remote_thrift_client PUBLIC remote_function_thrift FBThrift::thriftcpp2) -velox_add_library(velox_functions_remote_rest_client RestClient.cpp) -velox_link_libraries(velox_functions_remote_rest_client ${PROXYGEN_LIBRARIES} - velox_exec Folly::folly) +add_library(velox_functions_remote_rest_client RestClient.cpp) +target_link_libraries( + velox_functions_remote_rest_client Folly::folly CURL::libcurl) velox_add_library(velox_functions_remote Remote.cpp) velox_link_libraries( velox_functions_remote - PUBLIC velox_functions_remote_rest_client - velox_expression + PUBLIC velox_expression velox_memory velox_exec velox_vector velox_presto_serializer velox_functions_remote_thrift_client + velox_functions_remote_rest_client velox_functions_remote_get_serde velox_type_fbhive Folly::folly) diff --git a/velox/functions/remote/client/Remote.cpp b/velox/functions/remote/client/Remote.cpp index db68fceb91fb5..787462c7559cb 100644 --- a/velox/functions/remote/client/Remote.cpp +++ b/velox/functions/remote/client/Remote.cpp @@ -20,7 +20,6 @@ #include "velox/common/memory/ByteStream.h" #include "velox/expression/Expr.h" #include "velox/expression/VectorFunction.h" -#include "velox/functions/remote/client/RestClient.h" #include "velox/functions/remote/client/ThriftClient.h" #include "velox/functions/remote/if/GetSerde.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionServiceAsyncClient.h" @@ -29,18 +28,17 @@ #include "velox/vector/FlatVector.h" #include "velox/vector/VectorStream.h" -#include -#include +#include #include #include +#include "RestClient.h" + using namespace folly; -using namespace proxygen; namespace facebook::velox::functions { namespace { std::string serializeType(const TypePtr& type) { - // Use hive type serializer. return type::fbhive::HiveTypeSerializer::serialize(type); } @@ -72,13 +70,16 @@ class RemoteFunction : public exec::VectorFunction { RemoteFunction( const std::string& functionName, const std::vector& inputArgs, - const RemoteVectorFunctionMetadata& metadata) - : functionName_(functionName), metadata_(metadata) { + const RemoteVectorFunctionMetadata& metadata, + std::unique_ptr httpClient = nullptr) + : functionName_(functionName), + restClient_(httpClient ? std::move(httpClient) : getRestClient()), + metadata_(metadata) { if (metadata.location.type() == typeid(SocketAddress)) { location_ = boost::get(metadata.location); thriftClient_ = getThriftClient(location_, &eventBase_); - } else if (metadata.location.type() == typeid(URL)) { - url_ = boost::get(metadata.location); + } else if (metadata.location.type() == typeid(std::string)) { + url_ = boost::get(metadata.location); } std::vector types; @@ -101,7 +102,7 @@ class RemoteFunction : public exec::VectorFunction { try { if ((metadata_.location.type() == typeid(SocketAddress))) { applyRemote(rows, args, outputType, context, result); - } else if (metadata_.location.type() == typeid(URL)) { + } else if (metadata_.location.type() == typeid(std::string)) { applyRestRemote(rows, args, outputType, context, result); } } catch (const VeloxRuntimeError&) { @@ -119,18 +120,7 @@ class RemoteFunction : public exec::VectorFunction { exec::EvalCtx& context, VectorPtr& result) const { try { - std::string fullUrl = fmt::format( - "{}/v1/functions/{}/{}/{}/{}", - url_.getUrl(), - metadata_.schema.value_or("default"), - extractFunctionName(functionName_), - urlEncode(metadata_.functionId.value_or("default_function_id")), - metadata_.version.value_or("1")); - - // Serialize the input data serializer::presto::PrestoVectorSerde serde; - serializer::presto::PrestoVectorSerde::PrestoOptions options; - auto remoteRowVector = std::make_shared( context.pool(), remoteInputType_, @@ -138,22 +128,26 @@ class RemoteFunction : public exec::VectorFunction { rows.end(), std::move(args)); - // Serialize the RowVector into an IOBuf (binary format) - IOBuf payload = rowVectorToIOBuf( - remoteRowVector, rows.end(), *context.pool(), &serde); + std::unique_ptr requestBody = + std::make_unique(rowVectorToIOBuf( + remoteRowVector, rows.end(), *context.pool(), &serde)); - // Send the serialized data to the remote function via RestClient - RestClient restClient(fullUrl); - std::unique_ptr responseBody; - restClient.invoke_function( - std::make_unique(std::move(payload)), (responseBody)); + const std::string fullUrl = fmt::format( + "{}/v1/functions/{}/{}/{}/{}", + url_, + metadata_.schema.value_or("default"), + extractFunctionName(functionName_), + urlEncode(metadata_.functionId.value_or("default_function_id")), + metadata_.version.value_or("1")); + + std::unique_ptr responseBody = + restClient_->performCurlRequest(fullUrl, std::move(requestBody)); auto outputRowVector = IOBufToRowVector( *responseBody, ROW({outputType}), *context.pool(), &serde); - result = outputRowVector->childAt(0); + result = outputRowVector->childAt(0); } catch (const std::exception& e) { - // Catch and handle any exceptions thrown during the process VELOX_FAIL( "Error while executing remote function '{}': {}", functionName_, @@ -238,11 +232,11 @@ class RemoteFunction : public exec::VectorFunction { } const std::string functionName_; - EventBase eventBase_; std::unique_ptr thriftClient_; + std::unique_ptr restClient_; SocketAddress location_; - URL url_; + std::string url_; RowTypePtr remoteInputType_; std::vector serializedInputTypes_; const RemoteVectorFunctionMetadata metadata_; diff --git a/velox/functions/remote/client/Remote.h b/velox/functions/remote/client/Remote.h index 88b5544c172be..09ea16dc426e2 100644 --- a/velox/functions/remote/client/Remote.h +++ b/velox/functions/remote/client/Remote.h @@ -18,7 +18,6 @@ #include #include -#include #include "velox/expression/VectorFunction.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h" @@ -29,7 +28,7 @@ struct RemoteVectorFunctionMetadata : public exec::VectorFunctionMetadata { /// Or Network address of the servr to communicate with. Note that this can /// hold a network location (ip/port pair) or a unix domain socket path (see /// SocketAddress::makeFromPath()). - boost::variant location; + boost::variant location; /// The serialization format to be used when sending data to the remote. remote::PageFormat serdeFormat{remote::PageFormat::PRESTO_PAGE}; diff --git a/velox/functions/remote/client/RestClient.cpp b/velox/functions/remote/client/RestClient.cpp index 7835dbd9c9d9b..4fb5f8d908e52 100644 --- a/velox/functions/remote/client/RestClient.cpp +++ b/velox/functions/remote/client/RestClient.cpp @@ -13,22 +13,101 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "RestClient.h" -#include -#include +#include "velox/functions/remote/client/RestClient.h" + +#include +#include +#include + +#include "velox/common/base/Exceptions.h" + +using namespace folly; namespace facebook::velox::functions { +namespace { +size_t readCallback(char* dest, size_t size, size_t nmemb, void* userp) { + auto* inputBufQueue = static_cast(userp); + size_t bufferSize = size * nmemb; + size_t totalCopied = 0; -// RestClient Implementation -RestClient::RestClient(const std::string& url) : url_(URL(url)) { - httpClient_ = std::make_shared(url_); + while (totalCopied < bufferSize && !inputBufQueue->empty()) { + auto buf = inputBufQueue->front(); + size_t remainingSize = bufferSize - totalCopied; + size_t copySize = std::min(remainingSize, buf->length()); + std::memcpy(dest + totalCopied, buf->data(), copySize); + totalCopied += copySize; + inputBufQueue->pop_front(); + } + + return totalCopied; +} +size_t writeCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { + auto* outputBuf = static_cast(userdata); + size_t total_size = size * nmemb; + auto buf = IOBuf::copyBuffer(ptr, total_size); + outputBuf->append(std::move(buf)); + return total_size; } +} // namespace + +std::unique_ptr RestClient::performCurlRequest( + const std::string& fullUrl, + std::unique_ptr requestPayload) { + try { + IOBufQueue inputBufQueue(IOBufQueue::cacheChainLength()); + inputBufQueue.append(std::move(requestPayload)); + + CURL* curl = curl_easy_init(); + if (!curl) { + VELOX_FAIL(fmt::format( + "Error initializing CURL: {}", + curl_easy_strerror(CURLE_FAILED_INIT))); + } + + curl_easy_setopt(curl, CURLOPT_URL, fullUrl.c_str()); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_READFUNCTION, readCallback); + curl_easy_setopt(curl, CURLOPT_READDATA, &inputBufQueue); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback); + + IOBufQueue outputBuf(IOBufQueue::cacheChainLength()); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &outputBuf); + curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L); -void RestClient::invoke_function( - std::unique_ptr requestBody, - std::unique_ptr& responseBody) const { - httpClient_->send(std::move(requestBody)); - responseBody = httpClient_->getResponseBody(); -}; + struct curl_slist* headers = nullptr; + headers = + curl_slist_append(headers, "Content-Type: application/X-presto-pages"); + headers = curl_slist_append(headers, "Accept: application/X-presto-pages"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + curl_easy_setopt( + curl, + CURLOPT_POSTFIELDSIZE, + static_cast(inputBufQueue.chainLength())); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + VELOX_FAIL(fmt::format( + "Error communicating with server: {}\nURL: {}\nCURL Error: {}", + curl_easy_strerror(res), + fullUrl.c_str(), + curl_easy_strerror(res))); + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return outputBuf.move(); + + } catch (const std::exception& e) { + VELOX_FAIL(fmt::format("Exception during CURL request: {}", e.what())); + } +} + +std::unique_ptr getRestClient() { + return std::make_unique(); +} } // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/RestClient.h b/velox/functions/remote/client/RestClient.h index 79772eb5d511d..243ff99e0923b 100644 --- a/velox/functions/remote/client/RestClient.h +++ b/velox/functions/remote/client/RestClient.h @@ -9,125 +9,36 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, + * either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ -#pragma once -#include -#include -#include -#include -#include -#include -#include -#include "velox/functions/remote/client/RestClient.h" +#pragma once -using namespace proxygen; -using namespace folly; +#include +#include +#include namespace facebook::velox::functions { -class HttpClient : public HTTPConnector::Callback, - public HTTPTransactionHandler { +class HttpClient { public: - HttpClient(const URL& url) : url_(url) {} - - void send(std::unique_ptr requestBody) { - requestBody_ = std::move(requestBody); - connector_ = std::make_unique( - this, WheelTimerInstance(std::chrono::milliseconds(1000))); - connector_->connect( - &evb_, - SocketAddress(url_.getHost(), url_.getPort(), true), - std::chrono::milliseconds(10000)); - evb_.loop(); - } - - std::unique_ptr getResponseBody() { - return std::move(responseBody_); - } - - private: - URL url_; - EventBase evb_; - std::unique_ptr connector_; - std::shared_ptr session_; - std::unique_ptr requestBody_; - std::unique_ptr responseBody_; - - void connectSuccess(HTTPUpstreamSession* session) noexcept override { - session_ = std::shared_ptr( - session, [](HTTPUpstreamSession* s) { - // No-op deleter, managed by Proxygen - }); - sendRequest(); - } - - void connectError(const AsyncSocketException& ex) noexcept override { - LOG(ERROR) << "Failed to connect: " << ex.what(); - evb_.terminateLoopSoon(); - } - - void sendRequest() { - auto txn = session_->newTransaction(this); - HTTPMessage req; - req.setMethod(HTTPMethod::POST); - req.setURL(url_.getUrl()); - - req.getHeaders().add( - "Content-Length", - std::to_string(requestBody_->computeChainDataLength())); - req.getHeaders().add("Content-Type", "application/X-presto-pages"); - req.getHeaders().add("Accept", "application/X-presto-pages"); - - txn->sendHeaders(req); - txn->sendBody(std::move(requestBody_)); - txn->sendEOM(); - } + virtual ~HttpClient() = default; - void setTransaction(HTTPTransaction*) noexcept override {} - - void detachTransaction() noexcept override { - session_.reset(); - evb_.terminateLoopSoon(); - } - - void onHeadersComplete(std::unique_ptr msg) noexcept override { - LOG(INFO) << "Received headers"; - } - - void onBody(std::unique_ptr chain) noexcept override { - responseBody_ = std::move(chain); - } - - void onEOM() noexcept override { - LOG(INFO) << "Transaction complete"; - session_->drain(); - } - - void onError(const HTTPException& error) noexcept override { - LOG(ERROR) << "Error: " << error.what(); - } - - void onUpgrade(UpgradeProtocol) noexcept override {} - void onTrailers(std::unique_ptr) noexcept override {} - void onEgressPaused() noexcept override {} - void onEgressResumed() noexcept override {} + virtual std::unique_ptr performCurlRequest( + const std::string& url, + std::unique_ptr requestPayload) = 0; }; -class RestClient { +class RestClient : public HttpClient { public: - RestClient(const std::string& url); - - void invoke_function( - std::unique_ptr request, - std::unique_ptr& response) const; - - private: - URL url_; - std::shared_ptr httpClient_; + std::unique_ptr performCurlRequest( + const std::string& url, + std::unique_ptr requestPayload) override; }; +std::unique_ptr getRestClient(); + } // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/tests/CMakeLists.txt b/velox/functions/remote/client/tests/CMakeLists.txt index 38d0b25dbbd73..d31d143498891 100644 --- a/velox/functions/remote/client/tests/CMakeLists.txt +++ b/velox/functions/remote/client/tests/CMakeLists.txt @@ -38,7 +38,6 @@ target_link_libraries( velox_functions_remote_client_rest_test velox_functions_remote_server_rest velox_functions_remote - velox_function_registry velox_functions_test_lib velox_exec_test_lib GTest::gmock diff --git a/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp index 96c71fdd550a5..6ed004d4f6075 100644 --- a/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp +++ b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp @@ -15,11 +15,9 @@ */ #include -#include #include #include #include -#include #include "velox/common/base/Exceptions.h" #include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/Registerer.h" @@ -50,39 +48,38 @@ class RemoteFunctionRestTest metadata.serdeFormat = GetParam(); metadata.location = location_; - auto absSignature = {exec::FunctionSignatureBuilder() - .returnType("integer") - .argumentType("integer") - .build()}; - registerRemoteFunction("remote_abs", absSignature, metadata); + auto absSignature = exec::FunctionSignatureBuilder() + .returnType("integer") + .argumentType("integer") + .build(); + registerRemoteFunction("remote_abs", {absSignature}, metadata); - auto plusSignatures = {exec::FunctionSignatureBuilder() - .returnType("bigint") - .argumentType("bigint") - .argumentType("bigint") - .build()}; - registerRemoteFunction("remote_plus", plusSignatures, metadata); + auto plusSignature = exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("bigint") + .argumentType("bigint") + .build(); + registerRemoteFunction("remote_plus", {plusSignature}, metadata); RemoteVectorFunctionMetadata wrongMetadata = metadata; - wrongMetadata.location = folly::SocketAddress(); // empty address. - registerRemoteFunction("remote_wrong_port", plusSignatures, wrongMetadata); - - auto divSignatures = {exec::FunctionSignatureBuilder() - .returnType("double") - .argumentType("double") - .argumentType("double") - .build()}; - registerRemoteFunction("remote_divide", divSignatures, metadata); - - auto substrSignatures = {exec::FunctionSignatureBuilder() - .returnType("varchar") - .argumentType("varchar") - .argumentType("integer") - .build()}; - registerRemoteFunction("remote_substr", substrSignatures, metadata); - - // Registers the actual function under a different prefix. This is only - // needed for tests since the http service runs in the same process. + wrongMetadata.location = ""; + registerRemoteFunction("remote_wrong_port", {plusSignature}, wrongMetadata); + + auto divSignature = exec::FunctionSignatureBuilder() + .returnType("double") + .argumentType("double") + .argumentType("double") + .build(); + registerRemoteFunction("remote_divide", {divSignature}, metadata); + + auto substrSignature = exec::FunctionSignatureBuilder() + .returnType("varchar") + .argumentType("varchar") + .argumentType("integer") + .build(); + registerRemoteFunction("remote_substr", {substrSignature}, metadata); + + // Registers the actual functions under a different prefix. registerFunction( {remotePrefix_ + ".remote_abs"}); registerFunction( @@ -94,31 +91,26 @@ class RemoteFunctionRestTest } void initializeServer() { - HTTPServerOptions options; - options.idleTimeout = std::chrono::milliseconds(6000); - options.handlerFactories = - RequestHandlerChain() - .addThen(remotePrefix_) - .build(); - options.h2cEnabled = true; - - std::vector IPs = { - {folly::SocketAddress(location_.getHost(), location_.getPort(), true), - HTTPServer::Protocol::HTTP}}; - - server_ = std::make_shared(std::move(options)); - server_->bind(IPs); - - thread_ = std::make_unique([&] { server_->start(); }); + // Start the server in a separate thread + server_ = std::make_shared("127.0.0.1", 8321, remotePrefix_); + serverThread_ = std::make_unique([this]() { + try { + server_->run(); + } catch (const std::exception& ex) { + LOG(ERROR) << "Server exception: " << ex.what(); + } + }); VELOX_CHECK(waitForRunning(), "Unable to initialize HTTP server."); - LOG(INFO) << "HTTP server is up and running in local port " - << location_.getUrl(); + LOG(INFO) << "HTTP server is up and running at " << location_; } ~RemoteFunctionRestTest() override { - server_->stop(); - thread_->join(); + // Stop the server and join the thread + // server_->stop(); + if (serverThread_ && serverThread_->joinable()) { + serverThread_->join(); + } LOG(INFO) << "HTTP server stopped."; } @@ -133,21 +125,20 @@ class RemoteFunctionRestTest try { boost::asio::connect( - socket, - resolver.resolve( - location_.getHost(), std::to_string(location_.getPort()))); + socket, resolver.resolve("127.0.0.1", std::to_string(8321))); return true; } catch (std::exception& e) { - std::this_thread::sleep_for(std::chrono::milliseconds(500)); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } } return false; } - std::shared_ptr server_; - std::unique_ptr thread_; + std::shared_ptr server_; + std::unique_ptr serverThread_; - URL location_{URL("http://127.0.0.1:83211")}; + const std::string location_{ + "127.0.0.1:8321"}; // Update to match the server's address and port const std::string remotePrefix_{"remote"}; }; @@ -187,13 +178,14 @@ TEST_P(RemoteFunctionRestTest, connectionError) { "remote_wrong_port(c0, c0)", makeRowVector({inputVector})); }; - // Check it throw and that the exception has the "connection refused" - // substring. + // Check it throws and that the exception message contains "Error + // communicating with server" EXPECT_THROW(func(), VeloxRuntimeError); try { func(); } catch (const VeloxRuntimeError& e) { - EXPECT_THAT(e.message(), testing::HasSubstr("Channel is !good()")); + EXPECT_THAT( + e.message(), testing::HasSubstr("Error communicating with server")); } } diff --git a/velox/functions/remote/server/CMakeLists.txt b/velox/functions/remote/server/CMakeLists.txt index e6dac7c977994..785f1895f2ce3 100644 --- a/velox/functions/remote/server/CMakeLists.txt +++ b/velox/functions/remote/server/CMakeLists.txt @@ -1,10 +1,3 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -28,14 +21,13 @@ target_link_libraries( add_library(velox_functions_remote_server_rest RemoteFunctionRestService.cpp) target_link_libraries( velox_functions_remote_server_rest - ${PROXYGEN_LIBRARIES} + velox_functions_remote_get_serde velox_type_fbhive velox_memory - velox_functions_prestosql - velox_presto_serializer) + velox_functions_prestosql) add_executable(velox_functions_remote_server_rest_main - RemoteFunctionServiceRestMain.cpp) + RemoteFunctionServiceRestMain.cpp) target_link_libraries( velox_functions_remote_server_rest_main velox_functions_remote_server_rest diff --git a/velox/functions/remote/server/RemoteFunctionRestService.cpp b/velox/functions/remote/server/RemoteFunctionRestService.cpp index 77a9cda840c5e..89d2c325d9700 100644 --- a/velox/functions/remote/server/RemoteFunctionRestService.cpp +++ b/velox/functions/remote/server/RemoteFunctionRestService.cpp @@ -13,19 +13,23 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#include "velox/functions/remote/server/RemoteFunctionRestService.h" -#include -#include +#include "RemoteFunctionRestService.h" +#include +#include +#include +#include #include - #include "velox/expression/Expr.h" #include "velox/type/fbhive/HiveTypeParser.h" #include "velox/vector/VectorStream.h" namespace facebook::velox::functions { +namespace beast = boost::beast; +namespace http = beast::http; + namespace { + struct InternalFunctionSignature { std::vector argumentTypes; std::string returnType; @@ -61,14 +65,6 @@ RowTypePtr deserializeArgTypes(const std::vector& argTypes) { return ROW(std::move(typeNames), std::move(argumentTypes)); } -std::string getFunctionName( - const std::string& prefix, - const std::string& functionName) { - return prefix.empty() ? functionName - : fmt::format("{}.{}", prefix, functionName); -} -} // namespace - std::vector getExpressions( const RowTypePtr& inputType, const TypePtr& returnType, @@ -83,24 +79,57 @@ std::vector getExpressions( returnType, std::move(inputs), functionName)}; } -void RestRequestHandler::onRequest( - std::unique_ptr headers) noexcept { - const std::string& path = headers->getURL(); +std::string getFunctionName( + const std::string& prefix, + const std::string& functionName) { + return prefix.empty() ? functionName + : fmt::format("{}.{}", prefix, functionName); +} - // Split the path by '/' - std::vector pathComponents; - folly::split('/', path, pathComponents); +} // namespace - // Check if the path has enough components - if (pathComponents.size() >= 5) { - // Extract the functionName from the path - // Assuming the functionName is the 5th component - functionName_ = pathComponents[6]; - } -} +RestRequestHandler::RestRequestHandler(const std::string& functionPrefix) + : functionPrefix_(functionPrefix) {} + +void RestRequestHandler::handleRequest( + http::request>&& req, + std::function>&&)> send) { + http::response> res{ + http::status::ok, req.version()}; + res.set(http::field::server, BOOST_BEAST_VERSION_STRING); + res.set(http::field::content_type, "application/X-presto-pages"); + res.keep_alive(req.keep_alive()); -void RestRequestHandler::onEOM() noexcept { try { + // Parse the target URL + std::string target = req.target(); + + // Expected path format: + // /v1/functions/{schema}/{functionName}/{functionId}/{version} Split the + // path by '/' + std::vector pathComponents; + folly::split('/', target, pathComponents); + + // The first component is empty because the path starts with '/' + // So the components are: "", "v1", "functions", "{schema}", + // "{functionName}", "{functionId}", "{version}" + + if (pathComponents.size() >= 7 && pathComponents[1] == "v1" && + pathComponents[2] == "functions") { + schema_ = pathComponents[3]; + functionName_ = pathComponents[4]; + functionId_ = pathComponents[5]; + version_ = pathComponents[6]; + } else { + // Invalid path format + res.result(http::status::bad_request); + std::string errorMessage = "Invalid URL path"; + res.body().assign(errorMessage.begin(), errorMessage.end()); + res.prepare_payload(); + send(std::move(res)); + return; + } + const auto& functionSignature = internalFunctionSignatureMap.at(functionName_); @@ -108,7 +137,11 @@ void RestRequestHandler::onEOM() noexcept { auto returnType = deserializeType(functionSignature.returnType); serializer::presto::PrestoVectorSerde serde; - auto inputVector = IOBufToRowVector(*body_, inputType, *pool_, &serde); + + // Create an IOBuf from the request body + auto bodyData = + folly::IOBuf::wrapBuffer(req.body().data(), req.body().size()); + auto inputVector = IOBufToRowVector(*bodyData, inputType, *pool_, &serde); const vector_size_t numRows = inputVector->size(); SelectivityVector rows{numRows}; @@ -134,76 +167,161 @@ void RestRequestHandler::onEOM() noexcept { auto payload = rowVectorToIOBuf(outputRowVector, rows.end(), *pool_, &serde); - ResponseBuilder(downstream_) - .status(200, "OK") - .body(std::make_unique(payload)) - .sendWithEOM(); - + // Set the response body + res.body().resize(payload.computeChainDataLength()); + auto payloadData = payload.coalesce(); + std::memcpy(res.body().data(), payloadData.data(), res.body().size()); + res.prepare_payload(); + send(std::move(res)); } catch (const std::exception& ex) { - LOG(ERROR) << ex.what(); - ResponseBuilder(downstream_) - .status(500, "Internal Server Error") - .body(folly::IOBuf::copyBuffer(ex.what())) - .sendWithEOM(); + // Handle exceptions and send error response + res.result(http::status::internal_server_error); + std::string errorMessage = ex.what(); + res.body().assign(errorMessage.begin(), errorMessage.end()); + res.prepare_payload(); + send(std::move(res)); } } -void RestRequestHandler::onBody(std::unique_ptr body) noexcept { - if (body) { - body_ = std::move(body); - } -} +RestSession::RestSession( + boost::asio::ip::tcp::socket socket, + const std::string& functionPrefix) + : stream_(std::move(socket)), handler_(functionPrefix) {} -void RestRequestHandler::onUpgrade(UpgradeProtocol /*protocol*/) noexcept { - // handler doesn't support upgrades +void RestSession::run() { + doRead(); } -void RestRequestHandler::requestComplete() noexcept { - delete this; -} +void RestSession::doRead() { + auto self = shared_from_this(); -void RestRequestHandler::onError(ProxygenError /*err*/) noexcept { - delete this; -} + req_ = {}; -// ErrorHandler -ErrorHandler::ErrorHandler(int statusCode, std::string message) - : statusCode_(statusCode), message_(std::move(message)) {} + stream_.expires_after(std::chrono::seconds(30)); -void ErrorHandler::onRequest(std::unique_ptr) noexcept { - ResponseBuilder(downstream_) - .status(statusCode_, "Error") - .body(std::move(message_)) - .sendWithEOM(); + http::async_read( + stream_, + buffer_, + req_, + [self](beast::error_code ec, std::size_t bytes_transferred) { + self->onRead(ec, bytes_transferred); + }); } -void ErrorHandler::onEOM() noexcept {} +void RestSession::onRead( + beast::error_code ec, + std::size_t /*bytes_transferred*/) { + if (ec == http::error::end_of_stream) { + return doClose(); + } -void ErrorHandler::onBody(std::unique_ptr body) noexcept {} + if (ec) { + LOG(ERROR) << "Read error: " << ec.message(); + return; + } -void ErrorHandler::onUpgrade(UpgradeProtocol protocol) noexcept { - // handler doesn't support upgrades + // Handle the request + handler_.handleRequest( + std::move(req_), + [self = shared_from_this()]( + http::response>&& res) { + // Write the response + self->stream_.expires_after(std::chrono::seconds(30)); + http::async_write( + self->stream_, + res, + [self, res]( + beast::error_code ec, std::size_t bytes_transferred) mutable { + self->onWrite(res.need_eof(), ec, bytes_transferred); + }); + }); } -void ErrorHandler::requestComplete() noexcept { - delete this; +void RestSession::onWrite( + bool close, + beast::error_code ec, + std::size_t /*bytes_transferred*/) { + if (ec) { + LOG(ERROR) << "Write error: " << ec.message(); + return; + } + + if (close) { + return doClose(); + } + + // Clear the buffer and start reading another request + buffer_.consume(buffer_.size()); + doRead(); } -void ErrorHandler::onError(ProxygenError err) noexcept { - delete this; +void RestSession::doClose() { + beast::error_code ec; + stream_.socket().shutdown(boost::asio::ip::tcp::socket::shutdown_send, ec); } -// RestRequestHandlerFactory -void RestRequestHandlerFactory::onServerStart(folly::EventBase* evb) noexcept {} +RestServer::RestServer( + const std::string& address, + unsigned short port, + const std::string& functionPrefix) + : address_(address), + port_(port), + functionPrefix_(functionPrefix), + acceptor_(boost::asio::make_strand(ioc_)) {} + +void RestServer::run() { + boost::asio::ip::tcp::endpoint endpoint{ + boost::asio::ip::make_address(address_), port_}; + beast::error_code ec; + + // Open the acceptor + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + throw std::runtime_error("Failed to open acceptor: " + ec.message()); + } + + // Allow address reuse + acceptor_.set_option(boost::asio::socket_base::reuse_address(true), ec); + if (ec) { + throw std::runtime_error("Failed to set socket option: " + ec.message()); + } + + // Bind to the server address + acceptor_.bind(endpoint, ec); + if (ec) { + throw std::runtime_error("Failed to bind acceptor: " + ec.message()); + } + + // Start listening for connections + acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec); + if (ec) { + throw std::runtime_error("Failed to listen: " + ec.message()); + } + + doAccept(); + + // Run the I/O service + ioc_.run(); +} -void RestRequestHandlerFactory::onServerStop() noexcept {} +void RestServer::doAccept() { + acceptor_.async_accept( + boost::asio::make_strand(ioc_), + beast::bind_front_handler(&RestServer::onAccept, this)); +} -RequestHandler* RestRequestHandlerFactory::onRequest( - proxygen::RequestHandler*, - proxygen::HTTPMessage* msg) noexcept { - if (msg->getMethod() != HTTPMethod::POST) { - return new ErrorHandler(405, "Only POST method is allowed"); +void RestServer::onAccept( + beast::error_code ec, + boost::asio::ip::tcp::socket socket) { + if (!ec) { + // Create a new session to handle the request + std::make_shared(std::move(socket), functionPrefix_)->run(); + } else { + LOG(ERROR) << "Accept error: " << ec.message(); } - return new RestRequestHandler(functionPrefix_); + + // Accept another connection + doAccept(); } + } // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionRestService.h b/velox/functions/remote/server/RemoteFunctionRestService.h index 358965f753c0d..c9be6dbfdc3e0 100644 --- a/velox/functions/remote/server/RemoteFunctionRestService.h +++ b/velox/functions/remote/server/RemoteFunctionRestService.h @@ -16,56 +16,82 @@ #pragma once -#include +#include +#include +#include #include "velox/common/memory/Memory.h" -using namespace proxygen; - namespace facebook::velox::functions { -class ErrorHandler : public RequestHandler { - public: - explicit ErrorHandler(int statusCode, std::string message); - void onRequest(std::unique_ptr headers) noexcept override; - void onBody(std::unique_ptr) noexcept override; - void onEOM() noexcept override; - void onUpgrade(UpgradeProtocol protocol) noexcept override; - void requestComplete() noexcept override; - void onError(ProxygenError err) noexcept override; - - private: - int statusCode_; - std::string message_; -}; -class RestRequestHandler : public RequestHandler { +class RestRequestHandler { public: - explicit RestRequestHandler(const std::string& functionPrefix = "") - : functionPrefix_(functionPrefix) {} - void onRequest(std::unique_ptr headers) noexcept override; - void onBody(std::unique_ptr body) noexcept override; - void onEOM() noexcept override; - void onUpgrade(UpgradeProtocol protocol) noexcept override; - void requestComplete() noexcept override; - void onError(ProxygenError err) noexcept override; + explicit RestRequestHandler(const std::string& functionPrefix = ""); + + void handleRequest( + boost::beast::http::request>&& + req, + std::function>&&)> send); private: - std::unique_ptr body_; std::shared_ptr pool_{ memory::memoryManager()->addLeafPool()}; const std::string functionPrefix_; std::string functionName_; + std::string schema_; + std::string functionId_; + std::string version_; + + void processRequest( + const boost::beast::http::request< + boost::beast::http::vector_body>& req, + boost::beast::http::response>& + res); }; -class RestRequestHandlerFactory : public RequestHandlerFactory { +class RestSession : public std::enable_shared_from_this { public: - explicit RestRequestHandlerFactory(const std::string& functionPrefix = "") - : functionPrefix_(functionPrefix) {} - void onServerStart(folly::EventBase* evb) noexcept override; - void onServerStop() noexcept override; - RequestHandler* onRequest(RequestHandler*, HTTPMessage* msg) noexcept - override; + RestSession( + boost::asio::ip::tcp::socket socket, + const std::string& functionPrefix); + + void run(); private: - const std::string functionPrefix_; + void doRead(); + void onRead(boost::beast::error_code ec, std::size_t bytes_transferred); + void onWrite( + bool close, + boost::beast::error_code ec, + std::size_t bytes_transferred); + void doClose(); + + boost::beast::tcp_stream stream_; + boost::beast::flat_buffer buffer_; + RestRequestHandler handler_; + boost::beast::http::request> req_; }; + +class RestServer { + public: + RestServer( + const std::string& address, + unsigned short port, + const std::string& functionPrefix); + + void run(); + + private: + void doAccept(); + void onAccept( + boost::beast::error_code ec, + boost::asio::ip::tcp::socket socket); + + std::string address_; + unsigned short port_; + std::string functionPrefix_; + boost::asio::io_context ioc_; + boost::asio::ip::tcp::acceptor acceptor_; +}; + } // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp b/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp index 4444c8c2d5645..7f54230da3841 100644 --- a/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp +++ b/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp @@ -15,23 +15,18 @@ */ #include -#include #include "velox/common/memory/Memory.h" +#include "RemoteFunctionRestService.h" #include "velox/functions/Registerer.h" -#include "velox/functions/prestosql/Arithmetic.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" -#include "velox/functions/remote/server/RemoteFunctionRestService.h" DEFINE_string( service_host, "127.0.0.1", - "Prefix to be added to the functions being registered"); + "Host address to bind the HTTP server"); -DEFINE_int32( - service_port, - 8321, - "Prefix to be added to the functions being registered"); +DEFINE_int32(service_port, 8321, "Port number to bind the HTTP server"); DEFINE_string( function_prefix, @@ -45,33 +40,16 @@ int main(int argc, char* argv[]) { FLAGS_logtostderr = true; memory::initializeMemoryManager({}); - // A remote function service should handle the function execution by its own. - // But we use Velox framework for quick prototype here + // Register scalar functions with the specified prefix functions::prestosql::registerAllScalarFunctions(FLAGS_function_prefix); - // registerFunction( - // {"remote_plus"}); - // End of function registration - - LOG(INFO) << "Start HTTP Server at " << "http://" << FLAGS_service_host << ":" - << FLAGS_service_port; - - HTTPServerOptions options; - options.idleTimeout = std::chrono::milliseconds(60000); - options.handlerFactories = - RequestHandlerChain() - .addThen() - .build(); - options.h2cEnabled = true; - - std::vector IPs = { - {folly::SocketAddress(FLAGS_service_host, FLAGS_service_port, true), - HTTPServer::Protocol::HTTP}}; - HTTPServer server(std::move(options)); - server.bind(IPs); + // Start the HTTP server + LOG(INFO) << "Starting HTTP Server at " << "http://" << FLAGS_service_host + << ":" << FLAGS_service_port; - std::thread t([&]() { server.start(); }); + functions::RestServer server( + FLAGS_service_host, FLAGS_service_port, FLAGS_function_prefix); + server.run(); - t.join(); return 0; }