Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(functions): Add support for REST based remote functions #10911

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions velox/functions/remote/client/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,23 @@ velox_add_library(velox_functions_remote_thrift_client ThriftClient.cpp)
velox_link_libraries(velox_functions_remote_thrift_client
PUBLIC remote_function_thrift FBThrift::thriftcpp2)

# Resolve libcurl, which the REST client below links against.
# NOTE(review): curl_SOURCE is set to BUNDLED right before the dependency is
# resolved; presumably this forces the bundled build of curl instead of a
# system install -- confirm against velox_resolve_dependency.
set(curl_SOURCE BUNDLED)
velox_resolve_dependency(curl)

# REST client library for remote functions, built from RestClient.cpp and
# linked against folly and the curl libraries.
velox_add_library(velox_functions_remote_rest_client RestClient.cpp)
velox_link_libraries(velox_functions_remote_rest_client Folly::folly
${CURL_LIBRARIES})

# Remote function library built from Remote.cpp. It links both transport
# clients (thrift and rest) in addition to the expression, memory, exec,
# vector, serializer and type libraries listed below.
velox_add_library(velox_functions_remote Remote.cpp)
velox_link_libraries(
velox_functions_remote
PUBLIC velox_expression
velox_memory
velox_exec
velox_vector
velox_presto_serializer
velox_functions_remote_thrift_client
velox_functions_remote_rest_client
velox_functions_remote_get_serde
velox_type_fbhive
Folly::folly)
Expand Down
124 changes: 105 additions & 19 deletions velox/functions/remote/client/Remote.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,70 @@

#include "velox/functions/remote/client/Remote.h"

#include <fmt/format.h>
#include <folly/io/async/EventBase.h>
#include <sstream>
#include <string>

#include "velox/common/memory/ByteStream.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add empty line between the system and velox includes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added new line

#include "velox/expression/Expr.h"
#include "velox/expression/VectorFunction.h"
#include "velox/functions/remote/client/RestClient.h"
#include "velox/functions/remote/client/ThriftClient.h"
#include "velox/functions/remote/if/GetSerde.h"
#include "velox/functions/remote/if/gen-cpp2/RemoteFunctionServiceAsyncClient.h"
#include "velox/serializers/PrestoSerializer.h"
#include "velox/type/fbhive/HiveTypeSerializer.h"
#include "velox/vector/VectorStream.h"

using namespace folly;
namespace facebook::velox::functions {
namespace {

// Returns the string form of `type` as produced by the fbhive
// HiveTypeSerializer.
std::string serializeType(const TypePtr& type) {
  std::string serialized = type::fbhive::HiveTypeSerializer::serialize(type);
  return serialized;
}

// Returns the part of `input` that follows its last '.'; when `input`
// contains no '.', it is returned unchanged. An input ending in '.' yields an
// empty string.
std::string extractFunctionName(const std::string& input) {
  const std::string::size_type dotPos = input.rfind('.');
  return dotPos == std::string::npos ? input : input.substr(dotPos + 1);
}

// Percent-encodes `value` so it can be embedded in a URL.
//
// ASCII letters, digits and '-', '_', '.', '~' are copied through unchanged;
// every other byte is emitted as '%' followed by two lowercase hex digits.
//
// @param value Raw string to encode; it is processed byte by byte.
// @return The encoded string.
std::string urlEncode(const std::string& value) {
  static constexpr char kHexDigits[] = "0123456789abcdef";

  // Classify with explicit ASCII ranges instead of isalnum(): isalnum()
  // consults the current C locale, so under a non-"C" locale bytes >= 0x80
  // could be treated as alphanumeric and leak into the URL unescaped.
  const auto isUnreserved = [](unsigned char c) {
    return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') ||
        (c >= '0' && c <= '9') || c == '-' || c == '_' || c == '.' ||
        c == '~';
  };

  std::string escaped;
  escaped.reserve(value.size());
  for (const char ch : value) {
    const auto byte = static_cast<unsigned char>(ch);
    if (isUnreserved(byte)) {
      escaped.push_back(ch);
    } else {
      escaped.push_back('%');
      escaped.push_back(kHexDigits[byte >> 4]);
      escaped.push_back(kHexDigits[byte & 0x0F]);
    }
  }
  return escaped;
}

class RemoteFunction : public exec::VectorFunction {
public:
RemoteFunction(
const std::string& functionName,
const std::vector<exec::VectorFunctionArg>& inputArgs,
const RemoteVectorFunctionMetadata& metadata)
const RemoteVectorFunctionMetadata& metadata,
std::unique_ptr<HttpClient> httpClient = nullptr)
: functionName_(functionName),
location_(metadata.location),
thriftClient_(getThriftClient(location_, &eventBase_)),
serdeFormat_(metadata.serdeFormat),
serde_(getSerde(serdeFormat_)) {
restClient_(httpClient ? std::move(httpClient) : getRestClient()),
metadata_(metadata) {
if (metadata.location.type() == typeid(SocketAddress)) {
location_ = boost::get<SocketAddress>(metadata.location);
thriftClient_ = getThriftClient(location_, &eventBase_);
} else if (metadata.location.type() == typeid(std::string)) {
url_ = boost::get<std::string>(metadata.location);
}

std::vector<TypePtr> types;
types.reserve(inputArgs.size());
serializedInputTypes_.reserve(inputArgs.size());
Expand All @@ -62,7 +98,11 @@ class RemoteFunction : public exec::VectorFunction {
exec::EvalCtx& context,
VectorPtr& result) const override {
try {
applyRemote(rows, args, outputType, context, result);
if ((metadata_.location.type() == typeid(SocketAddress))) {
applyRemote(rows, args, outputType, context, result);
} else if (metadata_.location.type() == typeid(std::string)) {
applyRestRemote(rows, args, outputType, context, result);
}
} catch (const VeloxRuntimeError&) {
throw;
} catch (const std::exception&) {
Expand All @@ -71,6 +111,48 @@ class RemoteFunction : public exec::VectorFunction {
}

private:
void applyRestRemote(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const {
try {
serializer::presto::PrestoVectorSerde serde;
auto remoteRowVector = std::make_shared<RowVector>(
context.pool(),
remoteInputType_,
BufferPtr{},
rows.end(),
std::move(args));

std::unique_ptr<IOBuf> requestBody =
std::make_unique<IOBuf>(rowVectorToIOBuf(
remoteRowVector, rows.end(), *context.pool(), &serde));

const std::string fullUrl = fmt::format(
"{}/v1/functions/{}/{}/{}/{}",
url_,
metadata_.schema.value_or("default"),
extractFunctionName(functionName_),
urlEncode(metadata_.functionId.value_or("default_function_id")),
metadata_.version.value_or("1"));

std::unique_ptr<IOBuf> responseBody =
restClient_->invokeFunction(fullUrl, std::move(requestBody));

auto outputRowVector = IOBufToRowVector(
*responseBody, ROW({outputType}), *context.pool(), &serde);

result = outputRowVector->childAt(0);
} catch (const std::exception& e) {
VELOX_FAIL(
"Error while executing remote function '{}': {}",
functionName_,
e.what());
}
}

void applyRemote(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
Expand All @@ -97,11 +179,14 @@ class RemoteFunction : public exec::VectorFunction {

auto requestInputs = request.inputs_ref();
requestInputs->rowCount_ref() = remoteRowVector->size();
requestInputs->pageFormat_ref() = serdeFormat_;
requestInputs->pageFormat_ref() = metadata_.serdeFormat;

// TODO: serialize only active rows.
requestInputs->payload_ref() = rowVectorToIOBuf(
remoteRowVector, rows.end(), *context.pool(), serde_.get());
remoteRowVector,
rows.end(),
*context.pool(),
getSerde(metadata_.serdeFormat).get());

try {
thriftClient_->sync_invokeFunction(remoteResponse, request);
Expand All @@ -117,12 +202,15 @@ class RemoteFunction : public exec::VectorFunction {
remoteResponse.get_result().get_payload(),
ROW({outputType}),
*context.pool(),
serde_.get());
getSerde(metadata_.serdeFormat).get());
result = outputRowVector->childAt(0);

if (auto errorPayload = remoteResponse.get_result().errorPayload()) {
auto errorsRowVector = IOBufToRowVector(
*errorPayload, ROW({VARCHAR()}), *context.pool(), serde_.get());
*errorPayload,
ROW({VARCHAR()}),
*context.pool(),
getSerde(metadata_.serdeFormat).get());
auto errorsVector =
errorsRowVector->childAt(0)->asFlatVector<StringView>();
VELOX_CHECK(errorsVector, "Should be convertible to flat vector");
Expand All @@ -142,16 +230,14 @@ class RemoteFunction : public exec::VectorFunction {
}

const std::string functionName_;
folly::SocketAddress location_;

folly::EventBase eventBase_;
EventBase eventBase_;
std::unique_ptr<RemoteFunctionClient> thriftClient_;
remote::PageFormat serdeFormat_;
std::unique_ptr<VectorSerde> serde_;

// Structures we construct once to cache:
std::unique_ptr<HttpClient> restClient_;
SocketAddress location_;
std::string url_;
RowTypePtr remoteInputType_;
std::vector<std::string> serializedInputTypes_;
const RemoteVectorFunctionMetadata metadata_;
};

std::shared_ptr<exec::VectorFunction> createRemoteFunction(
Expand All @@ -169,7 +255,7 @@ void registerRemoteFunction(
std::vector<exec::FunctionSignaturePtr> signatures,
const RemoteVectorFunctionMetadata& metadata,
bool overwrite) {
exec::registerStatefulVectorFunction(
registerStatefulVectorFunction(
name,
signatures,
std::bind(
Expand Down
25 changes: 21 additions & 4 deletions velox/functions/remote/client/Remote.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,37 @@

#pragma once

#include <boost/variant.hpp>
#include <folly/SocketAddress.h>
#include "velox/expression/VectorFunction.h"
#include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h"

namespace facebook::velox::functions {

struct RemoteVectorFunctionMetadata : public exec::VectorFunctionMetadata {
/// Network address of the servr to communicate with. Note that this can hold
/// a network location (ip/port pair) or a unix domain socket path (see
/// URL of the HTTP/REST server for remote function.
/// Or Network address of the server to communicate with. Note that this can
/// hold a network location (ip/port pair) or a unix domain socket path (see
/// SocketAddress::makeFromPath()).
folly::SocketAddress location;
boost::variant<folly::SocketAddress, std::string> location;

/// The serialization format to be used
/// The serialization format to be used when sending data to the remote.
remote::PageFormat serdeFormat{remote::PageFormat::PRESTO_PAGE};

/// Optional schema defining the structure of the data or input/output types
/// involved in the remote function. This may include details such as column
/// names and data types.
std::optional<std::string> schema;

/// Optional identifier for the specific remote function to be invoked.
/// This can be useful when the same server hosts multiple functions,
/// and the client needs to specify which function to call.
std::optional<std::string> functionId;

/// Optional version information to be used when calling the remote function.
/// This can help in ensuring compatibility with a particular version of the
/// function if multiple versions are available on the server.
std::optional<std::string> version;
};

/// Registers a new remote function. It will use the metadata defined in
Expand Down
128 changes: 128 additions & 0 deletions velox/functions/remote/client/RestClient.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/functions/remote/client/RestClient.h"

#include <curl/curl.h>
#include <folly/io/IOBufQueue.h>

#include "velox/common/base/Exceptions.h"

using namespace folly;
namespace facebook::velox::functions {
namespace {

// Callback function for CURL to read data from the request payload.
// @param dest Destination buffer to copy data into.
// @param size Size of each data element.
// @param nmemb Number of elements to read.
// @param userp Pointer to user data (IOBufQueue containing the request
// payload).
// @return Number of bytes actually copied.
size_t readCallback(char* dest, size_t size, size_t nmemb, void* userp) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please write comments explaining the signature and the parameters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the documentation

auto* inputBufQueue = static_cast<IOBufQueue*>(userp);
size_t bufferSize = size * nmemb;
size_t totalCopied = 0;

while (totalCopied < bufferSize && !inputBufQueue->empty()) {
auto buf = inputBufQueue->front();
size_t remainingSize = bufferSize - totalCopied;
size_t copySize = std::min(remainingSize, buf->length());
std::memcpy(dest + totalCopied, buf->data(), copySize);
totalCopied += copySize;
inputBufQueue->pop_front();
}

return totalCopied;
}

// Write callback handed to libcurl (CURLOPT_WRITEFUNCTION) to collect the
// response payload. Each received chunk is copied into a new IOBuf and
// appended to the IOBufQueue behind `userData`.
//
// @param ptr Pointer to the received data.
// @param size Size of each data element.
// @param nmemb Number of elements received.
// @param userData Pointer to user data (IOBufQueue to store the response
// payload).
// @return Number of bytes consumed, always size * nmemb.
size_t writeCallback(char* ptr, size_t size, size_t nmemb, void* userData) {
  const size_t receivedBytes = size * nmemb;
  auto* responseQueue = static_cast<IOBufQueue*>(userData);
  responseQueue->append(IOBuf::copyBuffer(ptr, receivedBytes));
  return receivedBytes;
}
} // namespace

std::unique_ptr<IOBuf> RestClient::invokeFunction(
const std::string& fullUrl,
std::unique_ptr<IOBuf> requestPayload) {
try {
IOBufQueue inputBufQueue(IOBufQueue::cacheChainLength());
inputBufQueue.append(std::move(requestPayload));

CURL* curl = curl_easy_init();
if (!curl) {
VELOX_FAIL(fmt::format(
"Error initializing CURL: {}",
curl_easy_strerror(CURLE_FAILED_INIT)));
}

curl_easy_setopt(curl, CURLOPT_URL, fullUrl.c_str());
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_READFUNCTION, readCallback);
curl_easy_setopt(curl, CURLOPT_READDATA, &inputBufQueue);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writeCallback);

IOBufQueue outputBuf(IOBufQueue::cacheChainLength());
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &outputBuf);
curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);

struct curl_slist* headers = nullptr;
headers =
curl_slist_append(headers, "Content-Type: application/X-presto-pages");
headers = curl_slist_append(headers, "Accept: application/X-presto-pages");
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);

curl_easy_setopt(
curl,
CURLOPT_POSTFIELDSIZE,
static_cast<long>(inputBufQueue.chainLength()));

CURLcode res = curl_easy_perform(curl);
if (res != CURLE_OK) {
curl_slist_free_all(headers);
curl_easy_cleanup(curl);
VELOX_FAIL(fmt::format(
"Error communicating with server: {}\nURL: {}\nCURL Error: {}",
curl_easy_strerror(res),
fullUrl.c_str(),
curl_easy_strerror(res)));
}

curl_slist_free_all(headers);
curl_easy_cleanup(curl);

return outputBuf.move();

} catch (const std::exception& e) {
VELOX_FAIL(fmt::format("Exception during CURL request: {}", e.what()));
}
}

// Creates a new RestClient and returns it through the HttpClient interface.
std::unique_ptr<HttpClient> getRestClient() {
  std::unique_ptr<HttpClient> client = std::make_unique<RestClient>();
  return client;
}

} // namespace facebook::velox::functions
Loading
Loading