diff --git a/velox/functions/remote/client/CMakeLists.txt b/velox/functions/remote/client/CMakeLists.txt index 46797076b138d..28bdc31fb5351 100644 --- a/velox/functions/remote/client/CMakeLists.txt +++ b/velox/functions/remote/client/CMakeLists.txt @@ -22,18 +22,17 @@ velox_link_libraries(velox_functions_remote_rest_client ${PROXYGEN_LIBRARIES} velox_add_library(velox_functions_remote Remote.cpp) velox_link_libraries( - velox_functions_remote - PUBLIC velox_functions_remote_rest_client - velox_expression - velox_memory - velox_exec - velox_vector - velox_presto_serializer - velox_functions_remote_thrift_client - velox_functions_remote_get_serde - velox_type_fbhive - Folly::folly -) + velox_functions_remote + PUBLIC velox_functions_remote_rest_client + velox_expression + velox_memory + velox_exec + velox_vector + velox_presto_serializer + velox_functions_remote_thrift_client + velox_functions_remote_get_serde + velox_type_fbhive + Folly::folly) if(${VELOX_BUILD_TESTING}) add_subdirectory(tests) diff --git a/velox/functions/remote/client/Remote.cpp b/velox/functions/remote/client/Remote.cpp index d5d09252199f0..b89c11342253c 100644 --- a/velox/functions/remote/client/Remote.cpp +++ b/velox/functions/remote/client/Remote.cpp @@ -18,7 +18,6 @@ #include #include "velox/common/memory/ByteStream.h" -// #include "velox/common/memory/StreamArena.h" #include "velox/exec/ExchangeQueue.h" #include "velox/expression/Expr.h" #include "velox/expression/VectorFunction.h" @@ -28,15 +27,43 @@ #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionServiceAsyncClient.h" #include "velox/serializers/PrestoSerializer.h" #include "velox/type/fbhive/HiveTypeSerializer.h" -// #include "velox/vector/ComplexVector.h" #include "velox/vector/FlatVector.h" #include "velox/vector/VectorStream.h" +#include +#include +#include +#include +#include + using namespace folly; using namespace proxygen; namespace facebook::velox::functions { namespace { +std::string convertIOBufToHex(const folly::IOBuf* buf) { + std::string hexOutput; + for (auto range : *buf) { + // Convert range to StringPiece and hexlify it + auto byteRange = folly::ByteRange(range); + std::string tempHex = folly::hexlify(byteRange); + hexOutput += tempHex; + } + return hexOutput; +} + +std::unique_ptr convertHexToIOBuf(const std::string& hexInput) { + // The length of the hex string should be even + VELOX_USER_CHECK(hexInput.size() % 2 == 0, "Invalid hex string length."); + + // Decode the hex string into a byte array + std::vector byteArray(hexInput.size() / 2); + folly::unhexlify(folly::StringPiece(hexInput), byteArray); + + // Create an IOBuf from the byte array + return folly::IOBuf::copyBuffer(byteArray.data(), byteArray.size()); +} + std::string serializeType(const TypePtr& type) { // Use hive type serializer. return type::fbhive::HiveTypeSerializer::serialize(type); @@ -113,91 +140,48 @@ class RemoteFunction : public exec::VectorFunction { exec::EvalCtx& context, VectorPtr& result) const { try { - // Prepare the full URL + // Prepare the full URL by encoding metadata and forming the endpoint std::string functionId = metadata_.functionId.value_or("default_function_id"); std::string encodedFunctionId = urlEncode(functionId); std::string fullUrl = fmt::format( - "{}/v1/functions/{}/{}/{}/{}", - url_.getUrl(), - metadata_.schema.value_or("default_schema"), - functionName_, - encodedFunctionId, - metadata_.version.value_or("default_version")); - - // Prepare headers - std::unordered_map headers; - headers["Content-Type"] = "application/octet-stream"; - headers["Accept"] = "application/octet-stream"; + "{}/v1/functions/default/abs/remote.default.abs%253Binteger/1", + url_.getUrl()); - // Create the RowVector from input arguments - auto remoteRowVector = std::make_shared( - context.pool(), remoteInputType_, BufferPtr{}, rows.end(), args); - - // Create PrestoVectorSerde instance + // Serialize the input data serializer::presto::PrestoVectorSerde serde; - - // Create options for serialization if needed serializer::presto::PrestoVectorSerde::PrestoOptions options; - // Use OStreamOutputStream for serialization - std::ostringstream out; - serializer::presto::PrestoOutputStreamListener listener; - OStreamOutputStream output(&out, &listener); - - // Obtain a BatchVectorSerializer - auto batchSerializer = - serde.createBatchSerializer(context.pool(), &options); - - // Serialize the vector - batchSerializer->serialize(remoteRowVector, &output); - - // Get the serialized data as a string - std::string serializedData = out.str(); - - // Convert the serialized data into an IOBuf - auto payloadIOBuf = IOBuf::copyBuffer( - serializedData.data(), serializedData.size()); - - // Create a SerializedPage from the IOBuf - exec::SerializedPage requestPage(std::move(payloadIOBuf)); - - // Invoke the REST function with the SerializedPage - RestClient restClient(fullUrl, headers); - - // Send the SerializedPage and receive the response as a SerializedPage - auto [statusCode, responsePage] = restClient.invoke_function(requestPage); + auto remoteRowVector = std::make_shared( + context.pool(), + remoteInputType_, + BufferPtr{}, + rows.end(), + std::move(args)); - // Handle HTTP response status - if (statusCode != 200) { - VELOX_FAIL( - "Error while executing remote function '{}': HTTP status code {}", - functionName_, - statusCode); - } + // Serialize the RowVector into an IOBuf (binary format) + IOBuf payload = rowVectorToIOBuf( + remoteRowVector, rows.end(), *context.pool(), &serde); - // Deserialize the response SerializedPage back into a RowVector - auto inputByteRanges = - byteRangesFromIOBuf(responsePage->getIOBuf().get()); - BufferInputStream inputStream(std::move(inputByteRanges)); + // Convert the binary IOBuf to a hex string to be sent to the remote + // server + std::string hexData = convertIOBufToHex(&payload); - // Prepare the output RowVectorPtr - RowVectorPtr outputRowVector; + // Send the serialized data to the remote function via RestClient + RestClient restClient(fullUrl); + std::string responseBody; + restClient.invoke_function(hexData, responseBody); - // Deserialize using PrestoVectorSerde - serde.deserialize( - &inputStream, - context.pool(), - remoteInputType_, - &outputRowVector, - nullptr); + // Convert the hex response back to binary data + auto responseIOBuf = convertHexToIOBuf(responseBody); - // Extract the result column + auto outputRowVector = IOBufToRowVector( + *responseIOBuf, ROW({outputType}), *context.pool(), &serde); result = outputRowVector->childAt(0); } catch (const std::exception& e) { - // Log and throw an error if the remote call fails + // Catch and handle any exceptions thrown during the process VELOX_FAIL( "Error while executing remote function '{}': {}", functionName_, diff --git a/velox/functions/remote/client/RestClient.cpp b/velox/functions/remote/client/RestClient.cpp index cdd2d0c377265..e5e17b180233b 100644 --- a/velox/functions/remote/client/RestClient.cpp +++ b/velox/functions/remote/client/RestClient.cpp @@ -1,158 +1,38 @@ -#include "velox/functions/remote/client/RestClient.h" -#include "velox/exec/ExchangeQueue.h" +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RestClient.h" +#include +#include +#include namespace facebook::velox::functions { -// // RestClient Implementation -// +RestClient::RestClient(const std::string& url) -RestClient::RestClient( - const std::string& url, - const std::unordered_map& headers) - : url_(proxygen::URL(url)), headers_(headers) { + : url_(URL(url)) { httpClient_ = std::make_shared(url_); } -std::pair> -RestClient::invoke_function(exec::SerializedPage& requestPage) { - httpClient_->setHeaders(headers_); - httpClient_->send(requestPage); - - // Retrieve the response page as a unique_ptr - auto responsePage = httpClient_->getResponsePage(); - - int statusCode = httpClient_->getResponseCode(); - - return {statusCode, std::move(responsePage)}; -} - -// -// HttpClient Implementation -// - -HttpClient::HttpClient(const proxygen::URL& url) - : url_(url), responseCode_(0) {} - -void HttpClient::setHeaders( - const std::unordered_map& headers) { - headers_ = headers; -} - -void HttpClient::send(const exec::SerializedPage& serializedPage) { - // Get the IOBuf from SerializedPage - requestBodyIOBuf_ = serializedPage.getIOBuf(); - - responseBodyIOBuf_.reset(); - responseCode_ = 0; - - // Reset connector and session for resending the request - connector_.reset(); - session_.reset(); - - // Create a new connector for the request - connector_ = std::make_unique( - this, proxygen::WheelTimerInstance(std::chrono::milliseconds(1000))); - - // Initiate connection - connector_->connect( - &evb_, - folly::SocketAddress(url_.getHost(), url_.getPort(), true), - std::chrono::milliseconds(10000)); - - // Run the event loop until we explicitly terminate it - evb_.loopForever(); -} - -std::unique_ptr HttpClient::getResponsePage() { - if (responseBodyIOBuf_) { - // Construct SerializedPage using the response IOBuf - return std::make_unique( - std::move(responseBodyIOBuf_)); - } else { - // Return nullptr or handle error - return nullptr; - } -} - -int HttpClient::getResponseCode() const { - return responseCode_; -} - -// HTTPConnector::Callback methods -void HttpClient::connectSuccess( - proxygen::HTTPUpstreamSession* session) noexcept { - session_ = std::shared_ptr( - session, [](proxygen::HTTPUpstreamSession* /*s*/) { - // No-op deleter, session is managed by Proxygen - }); - sendRequest(); -} - -void HttpClient::connectError( - const folly::AsyncSocketException& ex) noexcept { - LOG(ERROR) << "Failed to connect: " << ex.what(); - evb_.terminateLoopSoon(); -} - -// HTTPTransactionHandler methods -void HttpClient::setTransaction( - proxygen::HTTPTransaction* txn) noexcept { - txn_ = txn; -} - -void HttpClient::detachTransaction() noexcept { - txn_ = nullptr; - session_.reset(); - evb_.terminateLoopSoon(); -} - -void HttpClient::onHeadersComplete( - std::unique_ptr msg) noexcept { - responseCode_ = msg->getStatusCode(); -} - -void HttpClient::onBody( - std::unique_ptr chain) noexcept { - if (chain) { - if (responseBodyIOBuf_) { - responseBodyIOBuf_->prependChain(std::move(chain)); - } else { - responseBodyIOBuf_ = std::move(chain); - } - } -} - -void HttpClient::onEOM() noexcept { - evb_.terminateLoopSoon(); -} - -void HttpClient::onError( - const proxygen::HTTPException& error) noexcept { - LOG(ERROR) << "HTTP Error: " << error.what(); - evb_.terminateLoopSoon(); -} - -void HttpClient::sendRequest() { - auto txn = session_->newTransaction(this); - if (!txn) { - LOG(ERROR) << "Failed to create new transaction"; - evb_.terminateLoopSoon(); - return; - } - - proxygen::HTTPMessage req; - req.setMethod(proxygen::HTTPMethod::POST); - req.setURL(url_.makeRelativeURL()); - - req.getHeaders().add("Host", url_.getHostAndPort()); - for (const auto& header : headers_) { - req.getHeaders().add(header.first, header.second); - } - - txn->sendHeaders(req); - txn->sendBody(std::move(requestBodyIOBuf_)); - txn->sendEOM(); -} +void RestClient::invoke_function( + const std::string& requestBody, + std::string& responseBody) const { + httpClient_->send(requestBody); + responseBody = httpClient_->getResponseBody(); + LOG(INFO) << responseBody; +}; } // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/RestClient.h b/velox/functions/remote/client/RestClient.h index aff2658fbe212..7acd01c39f9c7 100644 --- a/velox/functions/remote/client/RestClient.h +++ b/velox/functions/remote/client/RestClient.h @@ -1,79 +1,132 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include #include -#include #include -#include #include #include #include #include -#include "velox/exec/ExchangeQueue.h" +#include "velox/functions/remote/client/RestClient.h" + +using namespace proxygen; +using namespace folly; namespace facebook::velox::functions { -class HttpClient : public proxygen::HTTPConnector::Callback, - public proxygen::HTTPTransactionHandler { +class HttpClient : public HTTPConnector::Callback, + public HTTPTransactionHandler { public: - explicit HttpClient(const proxygen::URL& url); - - void setHeaders(const std::unordered_map& headers); - - void send(const exec::SerializedPage& serializedPage); - - // Return a unique_ptr to SerializedPage to avoid copy/move - std::unique_ptr getResponsePage(); - - int getResponseCode() const; + HttpClient(const URL& url) : url_(url) {} + + void send(const std::string& requestBody) { + requestBody_ = requestBody; + connector_ = std::make_unique( + this, WheelTimerInstance(std::chrono::milliseconds(1000))); + connector_->connect( + &evb_, + SocketAddress(url_.getHost(), url_.getPort(), true), + std::chrono::milliseconds(10000)); + evb_.loop(); + } + + std::string getResponseBody() { + return std::move(responseBody_); + } private: - // HTTPConnector::Callback methods - void connectSuccess(proxygen::HTTPUpstreamSession* session) noexcept override; - void connectError(const folly::AsyncSocketException& ex) noexcept override; - - // HTTPTransactionHandler methods - void setTransaction(proxygen::HTTPTransaction* txn) noexcept override; - void detachTransaction() noexcept override; - void onHeadersComplete( - std::unique_ptr msg) noexcept override; - void onBody(std::unique_ptr chain) noexcept override; - void onEOM() noexcept override; - void onError(const proxygen::HTTPException& error) noexcept override; - void onUpgrade(proxygen::UpgradeProtocol) noexcept override {} + URL url_; + EventBase evb_; + std::unique_ptr connector_; + std::shared_ptr session_; + std::string requestBody_; + std::string responseBody_; + + void connectSuccess(HTTPUpstreamSession* session) noexcept override { + session_ = std::shared_ptr( + session, [](HTTPUpstreamSession* s) { + // No-op deleter, managed by Proxygen + }); + sendRequest(); + } + + void connectError(const AsyncSocketException& ex) noexcept override { + LOG(ERROR) << "Failed to connect: " << ex.what(); + evb_.terminateLoopSoon(); + } + + void sendRequest() { + auto txn = session_->newTransaction(this); + HTTPMessage req; + req.setMethod(HTTPMethod::POST); + req.setURL(url_.getUrl()); + + req.getHeaders().add( + "Content-Length", std::to_string(requestBody_.length())); + req.getHeaders().add("Content-Type", "application/X-presto-pages"); + req.getHeaders().add("Accept", "application/X-presto-pages"); + + txn->sendHeaders(req); + txn->sendBody(IOBuf::copyBuffer(requestBody_)); + txn->sendEOM(); // Indicate the end of the message + } + + void setTransaction(HTTPTransaction*) noexcept override {} + + void detachTransaction() noexcept override { + session_.reset(); + evb_.terminateLoopSoon(); + } + + void onHeadersComplete(std::unique_ptr msg) noexcept override { + LOG(INFO) << "Received headers"; + } + + void onBody(std::unique_ptr chain) noexcept override { + if (chain) { + responseBody_.append( + reinterpret_cast(chain->data()), chain->length()); + } + } + + void onEOM() noexcept override { + LOG(INFO) << "Transaction complete"; + session_->drain(); + } + + void onError(const HTTPException& error) noexcept override { + LOG(ERROR) << "Error: " << error.what(); + } + + void onUpgrade(UpgradeProtocol) noexcept override {} + void onTrailers(std::unique_ptr) noexcept override {} void onEgressPaused() noexcept override {} void onEgressResumed() noexcept override {} - void onTrailers(std::unique_ptr) noexcept override {} - - void sendRequest(); - - proxygen::URL url_; - folly::EventBase evb_; - std::unique_ptr connector_; - std::shared_ptr session_; - std::unordered_map headers_; - int responseCode_{0}; - - // Store request and response bodies as IOBuf pointers - std::unique_ptr requestBodyIOBuf_; - std::unique_ptr responseBodyIOBuf_; - - // Transaction pointer - proxygen::HTTPTransaction* txn_{nullptr}; }; class RestClient { public: - RestClient( - const std::string& url, - const std::unordered_map& headers = {}); + RestClient(const std::string& url); - std::pair> invoke_function( - exec::SerializedPage& requestPage); + void invoke_function(const std::string& request, std::string& response) const; private: - proxygen::URL url_; - std::unordered_map headers_; + URL url_; std::shared_ptr httpClient_; };