diff --git a/velox/functions/remote/client/CMakeLists.txt b/velox/functions/remote/client/CMakeLists.txt index 56663a29d04b8..0b3b6bb6acd08 100644 --- a/velox/functions/remote/client/CMakeLists.txt +++ b/velox/functions/remote/client/CMakeLists.txt @@ -16,11 +16,23 @@ velox_add_library(velox_functions_remote_thrift_client ThriftClient.cpp) velox_link_libraries(velox_functions_remote_thrift_client PUBLIC remote_function_thrift FBThrift::thriftcpp2) +set(curl_SOURCE BUNDLED) +velox_resolve_dependency(curl) + +velox_add_library(velox_functions_remote_rest_client RestClient.cpp) +velox_link_libraries(velox_functions_remote_rest_client Folly::folly + ${CURL_LIBRARIES}) + velox_add_library(velox_functions_remote Remote.cpp) velox_link_libraries( velox_functions_remote PUBLIC velox_expression + velox_memory + velox_exec + velox_vector + velox_presto_serializer velox_functions_remote_thrift_client + velox_functions_remote_rest_client velox_functions_remote_get_serde velox_type_fbhive Folly::folly) diff --git a/velox/functions/remote/client/Remote.cpp b/velox/functions/remote/client/Remote.cpp index 8458b84baaef2..aa58a9e75881f 100644 --- a/velox/functions/remote/client/Remote.cpp +++ b/velox/functions/remote/client/Remote.cpp @@ -16,34 +16,70 @@ #include "velox/functions/remote/client/Remote.h" +#include #include +#include +#include + +#include "velox/common/memory/ByteStream.h" #include "velox/expression/Expr.h" #include "velox/expression/VectorFunction.h" +#include "velox/functions/remote/client/RestClient.h" #include "velox/functions/remote/client/ThriftClient.h" #include "velox/functions/remote/if/GetSerde.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionServiceAsyncClient.h" +#include "velox/serializers/PrestoSerializer.h" #include "velox/type/fbhive/HiveTypeSerializer.h" #include "velox/vector/VectorStream.h" +using namespace folly; namespace facebook::velox::functions { namespace { std::string serializeType(const TypePtr& type) { - // Use hive type serializer. return type::fbhive::HiveTypeSerializer::serialize(type); } +std::string extractFunctionName(const std::string& input) { + size_t lastDot = input.find_last_of('.'); + if (lastDot != std::string::npos) { + return input.substr(lastDot + 1); + } + return input; +} + +std::string urlEncode(const std::string& value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + for (char c : value) { + if (isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '~') { + escaped << c; + } else { + escaped << '%' << std::setw(2) << int(static_cast(c)); + } + } + return escaped.str(); +} + class RemoteFunction : public exec::VectorFunction { public: RemoteFunction( const std::string& functionName, const std::vector& inputArgs, - const RemoteVectorFunctionMetadata& metadata) + const RemoteVectorFunctionMetadata& metadata, + std::unique_ptr httpClient = nullptr) : functionName_(functionName), - location_(metadata.location), - thriftClient_(getThriftClient(location_, &eventBase_)), - serdeFormat_(metadata.serdeFormat), - serde_(getSerde(serdeFormat_)) { + restClient_(httpClient ? std::move(httpClient) : getRestClient()), + metadata_(metadata) { + if (metadata.location.type() == typeid(SocketAddress)) { + location_ = boost::get(metadata.location); + thriftClient_ = getThriftClient(location_, &eventBase_); + } else if (metadata.location.type() == typeid(std::string)) { + url_ = boost::get(metadata.location); + } + std::vector types; types.reserve(inputArgs.size()); serializedInputTypes_.reserve(inputArgs.size()); @@ -62,7 +98,11 @@ class RemoteFunction : public exec::VectorFunction { exec::EvalCtx& context, VectorPtr& result) const override { try { - applyRemote(rows, args, outputType, context, result); + if ((metadata_.location.type() == typeid(SocketAddress))) { + applyRemote(rows, args, outputType, context, result); + } else if (metadata_.location.type() == typeid(std::string)) { + applyRestRemote(rows, args, outputType, context, result); + } } catch (const VeloxRuntimeError&) { throw; } catch (const std::exception&) { @@ -71,6 +111,48 @@ class RemoteFunction : public exec::VectorFunction { } private: + void applyRestRemote( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const { + try { + serializer::presto::PrestoVectorSerde serde; + auto remoteRowVector = std::make_shared( + context.pool(), + remoteInputType_, + BufferPtr{}, + rows.end(), + std::move(args)); + + std::unique_ptr requestBody = + std::make_unique(rowVectorToIOBuf( + remoteRowVector, rows.end(), *context.pool(), &serde)); + + const std::string fullUrl = fmt::format( + "{}/v1/functions/{}/{}/{}/{}", + url_, + metadata_.schema.value_or("default"), + extractFunctionName(functionName_), + urlEncode(metadata_.functionId.value_or("default_function_id")), + metadata_.version.value_or("1")); + + std::unique_ptr responseBody = + restClient_->invokeFunction(fullUrl, std::move(requestBody)); + + auto outputRowVector = IOBufToRowVector( + *responseBody, ROW({outputType}), *context.pool(), &serde); + + result = outputRowVector->childAt(0); + } catch (const std::exception& e) { + VELOX_FAIL( + "Error while executing remote function '{}': {}", + functionName_, + e.what()); + } + } + void applyRemote( const SelectivityVector& rows, std::vector& args, @@ -97,11 +179,14 @@ class RemoteFunction : public exec::VectorFunction { auto requestInputs = request.inputs_ref(); requestInputs->rowCount_ref() = remoteRowVector->size(); - requestInputs->pageFormat_ref() = serdeFormat_; + requestInputs->pageFormat_ref() = metadata_.serdeFormat; // TODO: serialize only active rows. requestInputs->payload_ref() = rowVectorToIOBuf( - remoteRowVector, rows.end(), *context.pool(), serde_.get()); + remoteRowVector, + rows.end(), + *context.pool(), + getSerde(metadata_.serdeFormat).get()); try { thriftClient_->sync_invokeFunction(remoteResponse, request); @@ -117,12 +202,15 @@ class RemoteFunction : public exec::VectorFunction { remoteResponse.get_result().get_payload(), ROW({outputType}), *context.pool(), - serde_.get()); + getSerde(metadata_.serdeFormat).get()); result = outputRowVector->childAt(0); if (auto errorPayload = remoteResponse.get_result().errorPayload()) { auto errorsRowVector = IOBufToRowVector( - *errorPayload, ROW({VARCHAR()}), *context.pool(), serde_.get()); + *errorPayload, + ROW({VARCHAR()}), + *context.pool(), + getSerde(metadata_.serdeFormat).get()); auto errorsVector = errorsRowVector->childAt(0)->asFlatVector(); VELOX_CHECK(errorsVector, "Should be convertible to flat vector"); @@ -142,16 +230,14 @@ class RemoteFunction : public exec::VectorFunction { } const std::string functionName_; - folly::SocketAddress location_; - - folly::EventBase eventBase_; + EventBase eventBase_; std::unique_ptr thriftClient_; - remote::PageFormat serdeFormat_; - std::unique_ptr serde_; - - // Structures we construct once to cache: + std::unique_ptr restClient_; + SocketAddress location_; + std::string url_; RowTypePtr remoteInputType_; std::vector serializedInputTypes_; + const RemoteVectorFunctionMetadata metadata_; }; std::shared_ptr createRemoteFunction( @@ -169,7 +255,7 @@ void registerRemoteFunction( std::vector signatures, const RemoteVectorFunctionMetadata& metadata, bool overwrite) { - exec::registerStatefulVectorFunction( + registerStatefulVectorFunction( name, signatures, std::bind( diff --git a/velox/functions/remote/client/Remote.h b/velox/functions/remote/client/Remote.h index a6a1e773dc812..16fa1db37ae90 100644 --- a/velox/functions/remote/client/Remote.h +++ b/velox/functions/remote/client/Remote.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include "velox/expression/VectorFunction.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h" @@ -23,13 +24,29 @@ namespace facebook::velox::functions { struct RemoteVectorFunctionMetadata : public exec::VectorFunctionMetadata { - /// Network address of the servr to communicate with. Note that this can hold - /// a network location (ip/port pair) or a unix domain socket path (see + /// URL of the HTTP/REST server for remote function. + /// Or Network address of the server to communicate with. Note that this can + /// hold a network location (ip/port pair) or a unix domain socket path (see /// SocketAddress::makeFromPath()). - folly::SocketAddress location; + boost::variant location; - /// The serialization format to be used + /// The serialization format to be used when sending data to the remote. remote::PageFormat serdeFormat{remote::PageFormat::PRESTO_PAGE}; + + /// Optional schema defining the structure of the data or input/output types + /// involved in the remote function. This may include details such as column + /// names and data types. + std::optional schema; + + /// Optional identifier for the specific remote function to be invoked. + /// This can be useful when the same server hosts multiple functions, + /// and the client needs to specify which function to call. + std::optional functionId; + + /// Optional version information to be used when calling the remote function. + /// This can help in ensuring compatibility with a particular version of the + /// function if multiple versions are available on the server. + std::optional version; }; /// Registers a new remote function. It will use the meatadata defined in diff --git a/velox/functions/remote/client/RestClient.cpp b/velox/functions/remote/client/RestClient.cpp new file mode 100644 index 0000000000000..228b853b46c29 --- /dev/null +++ b/velox/functions/remote/client/RestClient.cpp @@ -0,0 +1,112 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/remote/client/RestClient.h" + +#include +#include + +#include "velox/common/base/Exceptions.h" + +using namespace folly; +namespace facebook::velox::functions { +namespace { +size_t readCallback(char* dest, size_t size, size_t nmemb, void* userp) { + auto* inputBufQueue = static_cast(userp); + size_t bufferSize = size * nmemb; + size_t totalCopied = 0; + + while (totalCopied < bufferSize && !inputBufQueue->empty()) { + auto buf = inputBufQueue->front(); + size_t remainingSize = bufferSize - totalCopied; + size_t copySize = std::min(remainingSize, buf->length()); + std::memcpy(dest + totalCopied, buf->data(), copySize); + totalCopied += copySize; + inputBufQueue->pop_front(); + } + + return totalCopied; +} +size_t writeCallback(char* ptr, size_t size, size_t nmemb, void* userData) { + auto* outputBuf = static_cast(userData); + size_t totalSize = size * nmemb; + auto buf = IOBuf::copyBuffer(ptr, totalSize); + outputBuf->append(std::move(buf)); + return totalSize; +} +} // namespace + +std::unique_ptr RestClient::invokeFunction( + const std::string& fullUrl, + std::unique_ptr requestPayload) { + try { + IOBufQueue inputBufQueue(IOBufQueue::cacheChainLength()); + inputBufQueue.append(std::move(requestPayload)); + + CURL* curl = curl_easy_init(); + if (!curl) { + VELOX_FAIL(fmt::format( + "Error initializing CURL: {}", + curl_easy_strerror(CURLE_FAILED_INIT))); + } + + curl_easy_setopt(curl, CURLOPT_URL, fullUrl.c_str()); + curl_easy_setopt(curl, CURLOPT_POST, 1L); + curl_easy_setopt(curl, CURLOPT_READFUNCTION, readCallback); + curl_easy_setopt(curl, CURLOPT_READDATA, &inputBufQueue); + curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback); + + IOBufQueue outputBuf(IOBufQueue::cacheChainLength()); + curl_easy_setopt(curl, CURLOPT_WRITEDATA, &outputBuf); + curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L); + + struct curl_slist* headers = nullptr; + headers = + curl_slist_append(headers, "Content-Type: application/X-presto-pages"); + headers = curl_slist_append(headers, "Accept: application/X-presto-pages"); + curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers); + + curl_easy_setopt( + curl, + CURLOPT_POSTFIELDSIZE, + static_cast(inputBufQueue.chainLength())); + + CURLcode res = curl_easy_perform(curl); + if (res != CURLE_OK) { + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + VELOX_FAIL(fmt::format( + "Error communicating with server: {}\nURL: {}\nCURL Error: {}", + curl_easy_strerror(res), + fullUrl.c_str(), + curl_easy_strerror(res))); + } + + curl_slist_free_all(headers); + curl_easy_cleanup(curl); + + return outputBuf.move(); + + } catch (const std::exception& e) { + VELOX_FAIL(fmt::format("Exception during CURL request: {}", e.what())); + } +} + +std::unique_ptr getRestClient() { + return std::make_unique(); +} + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/RestClient.h b/velox/functions/remote/client/RestClient.h new file mode 100644 index 0000000000000..b2313021681a9 --- /dev/null +++ b/velox/functions/remote/client/RestClient.h @@ -0,0 +1,63 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace facebook::velox::functions { + +/// @brief Abstract interface for an HTTP client. +/// Provides a method to invoke a function by sending an HTTP request +/// and receiving a response, both in Presto's serialized wire format. +class HttpClient { + public: + virtual ~HttpClient() = default; + + /// @brief Invokes a function over HTTP. + /// @param url The endpoint URL to send the request to. + /// @param requestPayload The request payload in Presto's serialized wire + /// format. + /// @return A unique pointer to the response payload in Presto's serialized + /// wire format. + virtual std::unique_ptr invokeFunction( + const std::string& url, + std::unique_ptr requestPayload) = 0; +}; + +/// @brief Concrete implementation of HttpClient using REST. +/// Handles HTTP communication by sending requests and receiving responses +/// using RESTful APIs with payloads in Presto's serialized wire format. +class RestClient : public HttpClient { + public: + /// @brief Invokes a function over HTTP using REST. + /// @param url The endpoint URL to send the request to. + /// @param requestPayload The request payload in Presto's serialized wire + /// format. + /// @return A unique pointer to the response payload in Presto's serialized + /// wire format. + std::unique_ptr invokeFunction( + const std::string& url, + std::unique_ptr requestPayload) override; +}; + +/// @brief Factory function to create an instance of RestClient. +/// @return A unique pointer to an HttpClient implementation. +std::unique_ptr getRestClient(); + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/tests/CMakeLists.txt b/velox/functions/remote/client/tests/CMakeLists.txt index 1659ad9d7e5a3..35fd732e37d9c 100644 --- a/velox/functions/remote/client/tests/CMakeLists.txt +++ b/velox/functions/remote/client/tests/CMakeLists.txt @@ -27,3 +27,20 @@ target_link_libraries( GTest::gmock GTest::gtest GTest::gtest_main) + +add_executable(velox_functions_remote_client_rest_test + RemoteFunctionRestTest.cpp) + +add_test(velox_functions_remote_client_rest_test + velox_functions_remote_client_rest_test) + +target_link_libraries( + velox_functions_remote_client_rest_test + velox_functions_remote_rest_client + velox_functions_remote_server_rest + velox_functions_remote + velox_functions_test_lib + velox_exec_test_lib + GTest::gmock + GTest::gtest + GTest::gtest_main) diff --git a/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp new file mode 100644 index 0000000000000..2db44037a9a26 --- /dev/null +++ b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp @@ -0,0 +1,212 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/functions/Registerer.h" +#include "velox/functions/lib/CheckedArithmetic.h" +#include "velox/functions/prestosql/Arithmetic.h" +#include "velox/functions/prestosql/StringFunctions.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/remote/client/Remote.h" +#include "velox/functions/remote/server/RemoteFunctionRestService.h" + +using ::facebook::velox::test::assertEqualVectors; + +namespace facebook::velox::functions { +namespace { + +class RemoteFunctionRestTest + : public test::FunctionBaseTest, + public testing::WithParamInterface { + public: + void SetUp() override { + initializeServer(); + registerRemoteFunctions(); + } + + // Registers a few remote functions to be used in this test. + void registerRemoteFunctions() const { + RemoteVectorFunctionMetadata metadata; + metadata.serdeFormat = GetParam(); + metadata.location = location_; + + auto absSignature = {exec::FunctionSignatureBuilder() + .returnType("integer") + .argumentType("integer") + .build()}; + registerRemoteFunction("remote_abs", absSignature, metadata); + + auto plusSignatures = {exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("bigint") + .argumentType("bigint") + .build()}; + registerRemoteFunction("remote_plus", plusSignatures, metadata); + + RemoteVectorFunctionMetadata wrongMetadata = metadata; + wrongMetadata.location = folly::SocketAddress(); // empty address. + registerRemoteFunction("remote_wrong_port", plusSignatures, wrongMetadata); + + auto divSignatures = {exec::FunctionSignatureBuilder() + .returnType("double") + .argumentType("double") + .argumentType("double") + .build()}; + registerRemoteFunction("remote_divide", divSignatures, metadata); + + auto substrSignatures = {exec::FunctionSignatureBuilder() + .returnType("varchar") + .argumentType("varchar") + .argumentType("integer") + .build()}; + registerRemoteFunction("remote_substr", substrSignatures, metadata); + + // Registers the actual function under a different prefix. This is only + // needed for tests since the HTTP service runs in the same process. + registerFunction( + {remotePrefix_ + ".remote_abs"}); + registerFunction( + {remotePrefix_ + ".remote_plus"}); + registerFunction( + {remotePrefix_ + ".remote_divide"}); + registerFunction( + {remotePrefix_ + ".remote_substr"}); + } + + void initializeServer() { + // Adjusted for Boost.Beast server; the server is started in the main + // thread. + + // Start the server in a separate thread + serverThread_ = std::make_unique([this]() { + std::string serviceHost = "127.0.0.1"; + int32_t servicePort = 8321; + std::string functionPrefix = remotePrefix_; + + boost::asio::io_context ioc{1}; + + std::make_shared( + ioc, + boost::asio::ip::tcp::endpoint( + boost::asio::ip::make_address(serviceHost), servicePort), + functionPrefix) + ->run(); + + ioc.run(); + }); + + VELOX_CHECK(waitForRunning(), "Unable to initialize HTTP server."); + LOG(INFO) << "HTTP server is up and running at " << location_; + } + + ~RemoteFunctionRestTest() override { + // Signal the server thread to stop + serverThread_->detach(); + LOG(INFO) << "HTTP server stopped."; + } + + private: + bool waitForRunning() const { + for (size_t i = 0; i < 100; ++i) { + using boost::asio::ip::tcp; + boost::asio::io_context io_context; + + tcp::socket socket(io_context); + tcp::resolver resolver(io_context); + + try { + boost::asio::connect( + socket, resolver.resolve("127.0.0.1", std::to_string(8321))); + return true; + } catch (std::exception& e) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + } + return false; + } + + std::unique_ptr serverThread_; + + std::string location_{("http://127.0.0.1:8321")}; + const std::string remotePrefix_{"remote"}; +}; + +TEST_P(RemoteFunctionRestTest, absolute) { + auto inputVector = makeFlatVector({-10, -20}); + auto results = evaluate>( + "remote_abs(c0)", makeRowVector({inputVector})); + + auto expected = makeFlatVector({10, 20}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, simple) { + auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); + auto results = evaluate>( + "remote_plus(c0, c0)", makeRowVector({inputVector})); + + auto expected = makeFlatVector({2, 4, 6, 8, 10}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, string) { + auto inputVector = + makeFlatVector({"hello", "my", "remote", "world"}); + auto inputVector1 = makeFlatVector({2, 1, 3, 5}); + auto results = evaluate>( + "remote_substr(c0, c1)", makeRowVector({inputVector, inputVector1})); + + auto expected = makeFlatVector({"ello", "my", "mote", "d"}); + assertEqualVectors(expected, results); +} + +TEST_P(RemoteFunctionRestTest, connectionError) { + auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); + auto func = [&]() { + evaluate>( + "remote_wrong_port(c0, c0)", makeRowVector({inputVector})); + }; + + // Check it throws and that the exception has the "connection refused" + // substring. + EXPECT_THROW(func(), VeloxRuntimeError); + try { + func(); + } catch (const VeloxRuntimeError& e) { + EXPECT_THAT(e.message(), testing::HasSubstr("Channel is !good()")); + } +} + +VELOX_INSTANTIATE_TEST_SUITE_P( + RemoteFunctionRestTestFixture, + RemoteFunctionRestTest, + ::testing::Values(remote::PageFormat::PRESTO_PAGE)); + +} // namespace +} // namespace facebook::velox::functions + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/velox/functions/remote/server/CMakeLists.txt b/velox/functions/remote/server/CMakeLists.txt index ff2afa0fed6a8..3d958a339c798 100644 --- a/velox/functions/remote/server/CMakeLists.txt +++ b/velox/functions/remote/server/CMakeLists.txt @@ -24,3 +24,18 @@ add_executable(velox_functions_remote_server_main RemoteFunctionServiceMain.cpp) target_link_libraries( velox_functions_remote_server_main velox_functions_remote_server velox_functions_prestosql) + +add_library(velox_functions_remote_server_rest RemoteFunctionRestService.cpp) +target_link_libraries( + velox_functions_remote_server_rest + velox_functions_remote_get_serde + velox_type_fbhive + velox_memory + velox_functions_prestosql) + +add_executable(velox_functions_remote_server_rest_main + RemoteFunctionServiceRestMain.cpp) + +target_link_libraries( + velox_functions_remote_server_rest_main velox_functions_remote_server_rest + velox_functions_prestosql) diff --git a/velox/functions/remote/server/RemoteFunctionHelper.h b/velox/functions/remote/server/RemoteFunctionHelper.h new file mode 100644 index 0000000000000..e13df63265737 --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionHelper.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include "velox/expression/Expr.h" +#include "velox/type/Type.h" +#include "velox/type/fbhive/HiveTypeParser.h" + +namespace facebook::velox::functions { +inline std::string getFunctionName( + const std::string& prefix, + const std::string& functionName) { + return prefix.empty() ? functionName + : fmt::format("{}.{}", prefix, functionName); +} + +inline TypePtr deserializeType(const std::string& input) { + // Use hive type parser/serializer. + return type::fbhive::HiveTypeParser().parse(input); +} + +inline RowTypePtr deserializeArgTypes( + const std::vector& argTypes) { + const size_t argCount = argTypes.size(); + + std::vector argumentTypes; + std::vector typeNames; + argumentTypes.reserve(argCount); + typeNames.reserve(argCount); + + for (size_t i = 0; i < argCount; ++i) { + argumentTypes.emplace_back(deserializeType(argTypes[i])); + typeNames.emplace_back(fmt::format("c{}", i)); + } + return ROW(std::move(typeNames), std::move(argumentTypes)); +} + +inline std::vector getExpressions( + const RowTypePtr& inputType, + const TypePtr& returnType, + const std::string& functionName) { + std::vector inputs; + for (size_t i = 0; i < inputType->size(); ++i) { + inputs.push_back(std::make_shared( + inputType->childAt(i), inputType->nameOf(i))); + } + + return {std::make_shared( + returnType, std::move(inputs), functionName)}; +} + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionRestService.cpp b/velox/functions/remote/server/RemoteFunctionRestService.cpp new file mode 100644 index 0000000000000..a5c6824593936 --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionRestService.cpp @@ -0,0 +1,279 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "RemoteFunctionRestService.h" + +#include +#include +#include "velox/expression/Expr.h" +#include "velox/functions/remote/server/RemoteFunctionHelper.h" +#include "velox/vector/VectorStream.h" + +namespace facebook::velox::functions { + +namespace { + +struct InternalFunctionSignature { + std::vector argumentTypes; + std::string returnType; +}; + +std::map internalFunctionSignatureMap = + { + {"remote_abs", {{"integer"}, "integer"}}, + {"remote_plus", {{"bigint", "bigint"}, "bigint"}}, + {"remote_divide", {{"double", "double"}, "double"}}, + {"remote_substr", {{"varchar", "integer"}, "varchar"}}, + // Add more functions here as needed, registerRemoteFunction should be + // called to use the functions mentioned in this map +}; + +} // namespace + +session::session( + boost::asio::ip::tcp::socket socket, + std::string functionPrefix) + : socket_(std::move(socket)), functionPrefix_(std::move(functionPrefix)) {} + +void session::run() { + do_read(); +} + +void session::do_read() { + auto self = shared_from_this(); + boost::beast::http::async_read( + socket_, + buffer_, + req_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->on_read(ec, bytes_transferred); + }); +} + +void session::on_read( + boost::beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + + if (ec == boost::beast::http::error::end_of_stream) { + return do_close(); + } + + if (ec) { + LOG(ERROR) << "Read error: " << ec.message(); + return; + } + + handle_request(std::move(req_)); +} + +void session::handle_request( + boost::beast::http::request req) { + res_.version(req.version()); + res_.set(boost::beast::http::field::server, BOOST_BEAST_VERSION_STRING); + + if (req.method() != boost::beast::http::verb::post) { + res_.result(boost::beast::http::status::method_not_allowed); + res_.set(boost::beast::http::field::content_type, "text/plain"); + res_.body() = "Only POST method is allowed"; + res_.prepare_payload(); + + auto self = shared_from_this(); + boost::beast::http::async_write( + socket_, + res_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->on_write(true, ec, bytes_transferred); + }); + return; + } + + std::string path = req.target(); + + // Expected path format: + // /v1/functions/{schema}/{functionName}/{functionId}/{version} Split the + // path by '/' + std::vector pathComponents; + folly::split('/', path, pathComponents); + + std::string functionName; + if (pathComponents.size() >= 7 && pathComponents[1] == "v1" && + pathComponents[2] == "functions") { + functionName = pathComponents[4]; + } else { + res_.result(boost::beast::http::status::bad_request); + res_.set(boost::beast::http::field::content_type, "text/plain"); + res_.body() = "Invalid request path"; + res_.prepare_payload(); + + auto self = shared_from_this(); + boost::beast::http::async_write( + socket_, + res_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->on_write(true, ec, bytes_transferred); + }); + return; + } + + try { + const auto& functionSignature = + internalFunctionSignatureMap.at(functionName); + + auto inputType = deserializeArgTypes(functionSignature.argumentTypes); + auto returnType = deserializeType(functionSignature.returnType); + + serializer::presto::PrestoVectorSerde serde; + auto inputBuffer = folly::IOBuf::copyBuffer(req.body()); + auto inputVector = + IOBufToRowVector(*inputBuffer, inputType, *pool_, &serde); + + const vector_size_t numRows = inputVector->size(); + SelectivityVector rows{numRows}; + + // Expression boilerplate. + auto queryCtx = core::QueryCtx::create(); + core::ExecCtx execCtx{pool_.get(), queryCtx.get()}; + exec::ExprSet exprSet{ + getExpressions( + inputType, + returnType, + getFunctionName(functionPrefix_, functionName)), + &execCtx}; + exec::EvalCtx evalCtx(&execCtx, &exprSet, inputVector.get()); + + std::vector expressionResult; + exprSet.eval(rows, evalCtx, expressionResult); + + // Create output vector. + auto outputRowVector = std::make_shared( + pool_.get(), ROW({returnType}), BufferPtr(), numRows, expressionResult); + + auto payload = + rowVectorToIOBuf(outputRowVector, rows.end(), *pool_, &serde); + + res_.result(boost::beast::http::status::ok); + res_.set( + boost::beast::http::field::content_type, "application/octet-stream"); + res_.body() = payload.moveToFbString().toStdString(); + res_.prepare_payload(); + + auto self = shared_from_this(); + boost::beast::http::async_write( + socket_, + res_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->on_write(false, ec, bytes_transferred); + }); + + } catch (const std::exception& ex) { + LOG(ERROR) << ex.what(); + res_.result(boost::beast::http::status::internal_server_error); + res_.set(boost::beast::http::field::content_type, "text/plain"); + res_.body() = ex.what(); + res_.prepare_payload(); + + auto self = shared_from_this(); + boost::beast::http::async_write( + socket_, + res_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->on_write(true, ec, bytes_transferred); + }); + } +} + +void session::on_write( + bool close, + boost::beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + + if (ec) { + LOG(ERROR) << "Write error: " << ec.message(); + return; + } + + if (close) { + return do_close(); + } + + req_ = {}; + + do_read(); +} + +void session::do_close() { + boost::beast::error_code ec; + socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_send, ec); +} + +listener::listener( + boost::asio::io_context& ioc, + boost::asio::ip::tcp::endpoint endpoint, + std::string functionPrefix) + : ioc_(ioc), acceptor_(ioc), functionPrefix_(std::move(functionPrefix)) { + boost::beast::error_code ec; + + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + LOG(ERROR) << "Open error: " << ec.message(); + return; + } + + acceptor_.set_option(boost::asio::socket_base::reuse_address(true), ec); + if (ec) { + LOG(ERROR) << "Set_option error: " << ec.message(); + return; + } + + acceptor_.bind(endpoint, ec); + if (ec) { + LOG(ERROR) << "Bind error: " << ec.message(); + return; + } + + acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec); + if (ec) { + LOG(ERROR) << "Listen error: " << ec.message(); + return; + } +} + +void listener::run() { + do_accept(); +} + +void listener::do_accept() { + acceptor_.async_accept( + [self = shared_from_this()]( + boost::beast::error_code ec, boost::asio::ip::tcp::socket socket) { + self->on_accept(ec, std::move(socket)); + }); +} + +void listener::on_accept( + boost::beast::error_code ec, + boost::asio::ip::tcp::socket socket) { + if (ec) { + LOG(ERROR) << "Accept error: " << ec.message(); + } else { + std::make_shared(std::move(socket), functionPrefix_)->run(); + } + do_accept(); +} + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionRestService.h b/velox/functions/remote/server/RemoteFunctionRestService.h new file mode 100644 index 0000000000000..47905f5fecd33 --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionRestService.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include "velox/common/memory/Memory.h" + +namespace facebook::velox::functions { + +/// @brief Manages an individual HTTP session. +/// Handles reading HTTP requests, processing them, and sending responses. +/// This class re-hosts Velox functions and allows testing their functionality. +class session : public std::enable_shared_from_this { + public: + session(boost::asio::ip::tcp::socket socket, std::string functionPrefix); + + /// Starts the session by initiating a read operation. + void run(); + + private: + // Initiates an asynchronous read operation. + void do_read(); + + // Called when a read operation completes. + void on_read(boost::beast::error_code ec, std::size_t bytes_transferred); + + // Processes the HTTP request and prepares a response. + void handle_request( + boost::beast::http::request req); + + // Called when a write operation completes. + void on_write( + bool close, + boost::beast::error_code ec, + std::size_t bytes_transferred); + + // Closes the socket connection. + void do_close(); + + boost::asio::ip::tcp::socket socket_; + boost::beast::flat_buffer buffer_; + std::string functionPrefix_; + boost::beast::http::request req_; + boost::beast::http::response res_; + std::shared_ptr pool_{ + memory::memoryManager()->addLeafPool()}; +}; + +/// @brief Listens for incoming TCP connections and creates sessions. +/// Sets up a TCP acceptor to listen for client connections, +/// creating a new session for each accepted connection. +class listener : public std::enable_shared_from_this { + public: + listener( + boost::asio::io_context& ioc, + boost::asio::ip::tcp::endpoint endpoint, + std::string functionPrefix); + + /// Starts accepting incoming connections. + void run(); + + private: + // Initiates an asynchronous accept operation. + void do_accept(); + + // Called when an accept operation completes. + void on_accept( + boost::beast::error_code ec, + boost::asio::ip::tcp::socket socket); + + boost::asio::io_context& ioc_; + boost::asio::ip::tcp::acceptor acceptor_; + std::string functionPrefix_; +}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionService.cpp b/velox/functions/remote/server/RemoteFunctionService.cpp index 2cc7a0abac129..58365797bc4a9 100644 --- a/velox/functions/remote/server/RemoteFunctionService.cpp +++ b/velox/functions/remote/server/RemoteFunctionService.cpp @@ -18,54 +18,11 @@ #include "velox/common/base/Exceptions.h" #include "velox/expression/Expr.h" #include "velox/functions/remote/if/GetSerde.h" +#include "velox/functions/remote/server/RemoteFunctionHelper.h" #include "velox/type/fbhive/HiveTypeParser.h" #include "velox/vector/VectorStream.h" namespace facebook::velox::functions { -namespace { - -std::string getFunctionName( - const std::string& prefix, - const std::string& functionName) { - return prefix.empty() ? functionName - : fmt::format("{}.{}", prefix, functionName); -} - -TypePtr deserializeType(const std::string& input) { - // Use hive type parser/serializer. - return type::fbhive::HiveTypeParser().parse(input); -} - -RowTypePtr deserializeArgTypes(const std::vector& argTypes) { - const size_t argCount = argTypes.size(); - - std::vector argumentTypes; - std::vector typeNames; - argumentTypes.reserve(argCount); - typeNames.reserve(argCount); - - for (size_t i = 0; i < argCount; ++i) { - argumentTypes.emplace_back(deserializeType(argTypes[i])); - typeNames.emplace_back(fmt::format("c{}", i)); - } - return ROW(std::move(typeNames), std::move(argumentTypes)); -} - -} // namespace - -std::vector getExpressions( - const RowTypePtr& inputType, - const TypePtr& returnType, - const std::string& functionName) { - std::vector inputs; - for (size_t i = 0; i < inputType->size(); ++i) { - inputs.push_back(std::make_shared( - inputType->childAt(i), inputType->nameOf(i))); - } - - return {std::make_shared( - returnType, std::move(inputs), functionName)}; -} void RemoteFunctionServiceHandler::handleErrors( apache::thrift::field_ref result, diff --git a/velox/functions/remote/server/RemoteFunctionServiceMain.cpp b/velox/functions/remote/server/RemoteFunctionServiceMain.cpp index c92ab9231d114..92ff2791bb1f8 100644 --- a/velox/functions/remote/server/RemoteFunctionServiceMain.cpp +++ b/velox/functions/remote/server/RemoteFunctionServiceMain.cpp @@ -18,6 +18,7 @@ #include #include #include +#include "velox/functions/prestosql/StringFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/remote/server/RemoteFunctionService.h" @@ -36,7 +37,7 @@ DEFINE_string( DEFINE_string( function_prefix, - "json.test_schema.", + "remote.schema.", "Prefix to be added to the functions being registered"); using namespace ::facebook::velox; @@ -46,11 +47,14 @@ int main(int argc, char* argv[]) { folly::Init init{&argc, &argv, false}; FLAGS_logtostderr = true; + memory::initializeMemoryManager({}); + // Always registers all Presto functions and make them available under a // certain prefix/namespace. LOG(INFO) << "Registering Presto functions"; functions::prestosql::registerAllScalarFunctions(FLAGS_function_prefix); + std::remove(FLAGS_uds_path.c_str()); folly::SocketAddress location{ folly::SocketAddress::makeFromPath(FLAGS_uds_path)}; diff --git a/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp b/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp new file mode 100644 index 0000000000000..c988aed981e0f --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionServiceRestMain.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "RemoteFunctionRestService.h" +#include "velox/common/memory/Memory.h" +#include "velox/functions/Registerer.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" + +DEFINE_string(service_host, "127.0.0.1", "Host to bind the service to"); + +DEFINE_int32(service_port, 8321, "Port to bind the service to"); + +DEFINE_string( + function_prefix, + "remote.schema", + "Prefix to be added to the functions being registered"); + +using namespace ::facebook::velox; + +int main(int argc, char* argv[]) { + folly::Init init(&argc, &argv); + FLAGS_logtostderr = true; + memory::initializeMemoryManager({}); + + functions::prestosql::registerAllScalarFunctions(FLAGS_function_prefix); + boost::asio::io_context ioc{1}; + + std::make_shared( + ioc, + boost::asio::ip::tcp::endpoint( + boost::asio::ip::make_address(FLAGS_service_host), + FLAGS_service_port), + FLAGS_function_prefix) + ->run(); + + ioc.run(); + + return 0; +}