Skip to content

Commit

Permalink
apacheGH-43631: [C++] Add C++ implementation of Async C Data Interface (
Browse files Browse the repository at this point in the history
apache#44495)

### Rationale for this change
Building on apache#43632 which created the Async C Data Structures, this adds functions to `bridge.h`/`bridge.cc` to implement helpers for managing the Async C Data interfaces 

### What changes are included in this PR?
Two functions added to bridge.h:

1. `CreateAsyncDeviceStreamHandler` populates a `ArrowAsyncDeviceStreamHandler` and an `Executor` to provide a future that resolves to an `AsyncRecordBatchGenerator` to produce record batches as they are pushed asynchronously. The `ArrowAsyncDeviceStreamHandler` can then be passed to any asynchronous producer.
2. `ExportAsyncRecordBatchReader` takes a record batch generator and a schema, along with an `ArrowAsyncDeviceStreamHandler` to use for calling the callbacks to push data as it is available from the generator. 

### Are these changes tested?
Unit tests are added (currently only one test, more tests to be added)

### Are there any user-facing changes?
No

* GitHub Issue: apache#43631

Lead-authored-by: Matt Topol <[email protected]>
Co-authored-by: David Li <[email protected]>
Co-authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: Matt Topol <[email protected]>
  • Loading branch information
3 people authored Nov 11, 2024
1 parent d7e982c commit aab7d81
Show file tree
Hide file tree
Showing 7 changed files with 558 additions and 1 deletion.
5 changes: 4 additions & 1 deletion cpp/src/arrow/c/abi.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ struct ArrowAsyncTask {
// calling this, and so it must be released separately.
//
// It is only valid to call this method exactly once.
int (*extract_data)(struct ArrowArrayTask* self, struct ArrowDeviceArray* out);
int (*extract_data)(struct ArrowAsyncTask* self, struct ArrowDeviceArray* out);

// opaque task-specific data
void* private_data;
Expand All @@ -298,6 +298,9 @@ struct ArrowAsyncTask {
// control on the asynchronous stream processing. This object must be owned by the
// producer who creates it, and thus is responsible for cleaning it up.
struct ArrowAsyncProducer {
// The device type that this stream produces data on.
ArrowDeviceType device_type;

// A consumer must call this function to start receiving on_next_task calls.
//
// It *must* be valid to call this synchronously from within `on_next_task` or
Expand Down
348 changes: 348 additions & 0 deletions cpp/src/arrow/c/bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@

#include <algorithm>
#include <cerrno>
#include <condition_variable>
#include <cstring>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <string_view>
#include <utility>
Expand All @@ -37,8 +40,10 @@
#include "arrow/result.h"
#include "arrow/stl_allocator.h"
#include "arrow/type_traits.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/future.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
Expand Down Expand Up @@ -2511,4 +2516,347 @@ Result<std::shared_ptr<ChunkedArray>> ImportDeviceChunkedArray(
return ImportChunked</*IsDevice=*/true>(stream, mapper);
}

namespace {

class AsyncRecordBatchIterator {
public:
struct TaskWithMetadata {
ArrowAsyncTask task_;
std::shared_ptr<KeyValueMetadata> metadata_;
};

struct State {
State(uint64_t queue_size, DeviceMemoryMapper mapper)
: queue_size_{queue_size}, mapper_{std::move(mapper)} {}

Result<RecordBatchWithMetadata> next() {
TaskWithMetadata task;
{
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock,
[&] { return !error_.ok() || !batches_.empty() || end_of_stream_; });
if (!error_.ok()) {
return error_;
}

if (batches_.empty() && end_of_stream_) {
return IterationEnd<RecordBatchWithMetadata>();
}

task = std::move(batches_.front());
batches_.pop();
}

producer_->request(producer_, 1);
ArrowDeviceArray out;
if (task.task_.extract_data(&task.task_, &out) != 0) {
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return !error_.ok(); });
return error_;
}

ARROW_ASSIGN_OR_RAISE(auto batch, ImportDeviceRecordBatch(&out, schema_, mapper_));
return RecordBatchWithMetadata{std::move(batch), std::move(task.metadata_)};
}

const uint64_t queue_size_;
const DeviceMemoryMapper mapper_;
ArrowAsyncProducer* producer_;
DeviceAllocationType device_type_;

std::mutex mutex_;
std::shared_ptr<Schema> schema_;
std::condition_variable cv_;
std::queue<TaskWithMetadata> batches_;
bool end_of_stream_ = false;
Status error_{Status::OK()};
};

AsyncRecordBatchIterator(uint64_t queue_size, DeviceMemoryMapper mapper)
: state_{std::make_shared<State>(queue_size, std::move(mapper))} {}

explicit AsyncRecordBatchIterator(std::shared_ptr<State> state)
: state_{std::move(state)} {}

const std::shared_ptr<Schema>& schema() const { return state_->schema_; }

DeviceAllocationType device_type() const { return state_->device_type_; }

Result<RecordBatchWithMetadata> Next() { return state_->next(); }

static Future<std::shared_ptr<AsyncRecordBatchIterator::State>> Make(
AsyncRecordBatchIterator& iterator, struct ArrowAsyncDeviceStreamHandler* handler) {
auto iterator_fut = Future<std::shared_ptr<AsyncRecordBatchIterator::State>>::Make();

auto private_data = new PrivateData{iterator.state_};
private_data->fut_iterator_ = iterator_fut;

handler->private_data = private_data;
handler->on_schema = on_schema;
handler->on_next_task = on_next_task;
handler->on_error = on_error;
handler->release = release;
return iterator_fut;
}

private:
struct PrivateData {
explicit PrivateData(std::shared_ptr<State> state) : state_(std::move(state)) {}

std::shared_ptr<State> state_;
Future<std::shared_ptr<AsyncRecordBatchIterator::State>> fut_iterator_;
ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateData);
};

static int on_schema(struct ArrowAsyncDeviceStreamHandler* self,
struct ArrowSchema* stream_schema) {
auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
if (self->producer != nullptr) {
private_data->state_->producer_ = self->producer;
private_data->state_->device_type_ =
static_cast<DeviceAllocationType>(self->producer->device_type);
}

auto maybe_schema = ImportSchema(stream_schema);
if (!maybe_schema.ok()) {
private_data->fut_iterator_.MarkFinished(maybe_schema.status());
return EINVAL;
}

private_data->state_->schema_ = maybe_schema.MoveValueUnsafe();
private_data->fut_iterator_.MarkFinished(private_data->state_);
self->producer->request(self->producer,
static_cast<int64_t>(private_data->state_->queue_size_));
return 0;
}

static int on_next_task(ArrowAsyncDeviceStreamHandler* self, ArrowAsyncTask* task,
const char* metadata) {
auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);

if (task == nullptr) {
std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
private_data->state_->end_of_stream_ = true;
lock.unlock();
private_data->state_->cv_.notify_one();
return 0;
}

std::shared_ptr<KeyValueMetadata> kvmetadata;
if (metadata != nullptr) {
auto maybe_decoded = DecodeMetadata(metadata);
if (!maybe_decoded.ok()) {
private_data->state_->error_ = std::move(maybe_decoded).status();
private_data->state_->cv_.notify_one();
return EINVAL;
}

kvmetadata = std::move(maybe_decoded->metadata);
}

std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
private_data->state_->batches_.push({*task, std::move(kvmetadata)});
lock.unlock();
private_data->state_->cv_.notify_one();
return 0;
}

static void on_error(ArrowAsyncDeviceStreamHandler* self, int code, const char* message,
const char* metadata) {
auto* private_data = reinterpret_cast<PrivateData*>(self->private_data);
std::string message_str, metadata_str;
if (message != nullptr) {
message_str = message;
}
if (metadata != nullptr) {
metadata_str = metadata;
}

Status error = Status::FromDetailAndArgs(
StatusCode::UnknownError,
std::make_shared<AsyncErrorDetail>(code, message_str, std::move(metadata_str)),
std::move(message_str));

if (!private_data->fut_iterator_.is_finished()) {
private_data->fut_iterator_.MarkFinished(error);
return;
}

std::unique_lock<std::mutex> lock(private_data->state_->mutex_);
private_data->state_->error_ = std::move(error);
lock.unlock();
private_data->state_->cv_.notify_one();
}

static void release(ArrowAsyncDeviceStreamHandler* self) {
delete reinterpret_cast<PrivateData*>(self->private_data);
}

std::shared_ptr<State> state_;
};

struct AsyncProducer {
struct State {
struct ArrowAsyncProducer producer_;

std::mutex mutex_;
std::condition_variable cv_;
uint64_t pending_requests_{0};
Status error_{Status::OK()};
};

AsyncProducer(DeviceAllocationType device_type, struct ArrowSchema* schema,
struct ArrowAsyncDeviceStreamHandler* handler)
: handler_{handler}, state_{std::make_shared<State>()} {
state_->producer_.device_type = static_cast<ArrowDeviceType>(device_type);
state_->producer_.private_data = reinterpret_cast<void*>(state_.get());
state_->producer_.request = AsyncProducer::request;
state_->producer_.cancel = AsyncProducer::cancel;
handler_->producer = &state_->producer_;

if (int status = handler_->on_schema(handler_, schema) != 0) {
state_->error_ =
Status::UnknownError("Received error from handler::on_schema ", status);
}
}

struct PrivateTaskData {
PrivateTaskData(std::shared_ptr<State> producer, std::shared_ptr<RecordBatch> record)
: producer_{std::move(producer)}, record_(std::move(record)) {}

std::shared_ptr<State> producer_;
std::shared_ptr<RecordBatch> record_;
ARROW_DISALLOW_COPY_AND_ASSIGN(PrivateTaskData);
};

Status operator()(const std::shared_ptr<RecordBatch>& record) {
std::unique_lock<std::mutex> lock(state_->mutex_);
if (state_->pending_requests_ == 0) {
state_->cv_.wait(lock, [this]() -> bool {
return !state_->error_.ok() || state_->pending_requests_ > 0;
});
}

if (!state_->error_.ok()) {
return state_->error_;
}

if (state_->pending_requests_ > 0) {
state_->pending_requests_--;
lock.unlock();

ArrowAsyncTask task;
task.private_data = new PrivateTaskData{state_, record};
task.extract_data = AsyncProducer::extract_data;

if (int status = handler_->on_next_task(handler_, &task, nullptr) != 0) {
delete reinterpret_cast<PrivateTaskData*>(task.private_data);
return Status::UnknownError("Received error from handler::on_next_task ", status);
}
}

return Status::OK();
}

static void request(struct ArrowAsyncProducer* producer, int64_t n) {
auto* self = reinterpret_cast<State*>(producer->private_data);
{
std::lock_guard<std::mutex> lock(self->mutex_);
if (!self->error_.ok()) {
return;
}
self->pending_requests_ += n;
}
self->cv_.notify_all();
}

static void cancel(struct ArrowAsyncProducer* producer) {
auto* self = reinterpret_cast<State*>(producer->private_data);
{
std::lock_guard<std::mutex> lock(self->mutex_);
if (!self->error_.ok()) {
return;
}
self->error_ = Status::Cancelled("Consumer requested cancellation");
}
self->cv_.notify_all();
}

static int extract_data(struct ArrowAsyncTask* task, struct ArrowDeviceArray* out) {
std::unique_ptr<PrivateTaskData> private_data{
reinterpret_cast<PrivateTaskData*>(task->private_data)};
int ret = 0;
if (out != nullptr) {
auto status = ExportDeviceRecordBatch(*private_data->record_,
private_data->record_->GetSyncEvent(), out);
if (!status.ok()) {
std::lock_guard<std::mutex> lock(private_data->producer_->mutex_);
private_data->producer_->error_ = status;
}
}

return ret;
}

struct ArrowAsyncDeviceStreamHandler* handler_;
std::shared_ptr<State> state_;
};

} // namespace

Future<AsyncRecordBatchGenerator> CreateAsyncDeviceStreamHandler(
struct ArrowAsyncDeviceStreamHandler* handler, internal::Executor* executor,
uint64_t queue_size, DeviceMemoryMapper mapper) {
auto iterator =
std::make_shared<AsyncRecordBatchIterator>(queue_size, std::move(mapper));
return AsyncRecordBatchIterator::Make(*iterator, handler)
.Then([executor](std::shared_ptr<AsyncRecordBatchIterator::State> state)
-> Result<AsyncRecordBatchGenerator> {
AsyncRecordBatchGenerator gen{state->schema_, state->device_type_, nullptr};
auto it =
Iterator<RecordBatchWithMetadata>(AsyncRecordBatchIterator{std::move(state)});
ARROW_ASSIGN_OR_RAISE(gen.generator,
MakeBackgroundGenerator(std::move(it), executor));
return gen;
});
}

Future<> ExportAsyncRecordBatchReader(
std::shared_ptr<Schema> schema,
AsyncGenerator<std::shared_ptr<RecordBatch>> generator,
DeviceAllocationType device_type, struct ArrowAsyncDeviceStreamHandler* handler) {
if (!schema) {
handler->on_error(handler, EINVAL, "Schema is null", nullptr);
handler->release(handler);
return Future<>::MakeFinished(Status::Invalid("Schema is null"));
}

struct ArrowSchema c_schema;
SchemaExportGuard guard(&c_schema);

auto status = ExportSchema(*schema, &c_schema);
if (!status.ok()) {
handler->on_error(handler, EINVAL, status.message().c_str(), nullptr);
handler->release(handler);
return Future<>::MakeFinished(status);
}

return VisitAsyncGenerator(generator, AsyncProducer{device_type, &c_schema, handler})
.Then(
[handler]() -> Status {
int status = handler->on_next_task(handler, nullptr, nullptr);
handler->release(handler);
if (status != 0) {
return Status::UnknownError("Received error from handler::on_next_task ",
status);
}
return Status::OK();
},
[handler](const Status status) -> Status {
handler->on_error(handler, EINVAL, status.message().c_str(), nullptr);
handler->release(handler);
return status;
});
}

} // namespace arrow
Loading

0 comments on commit aab7d81

Please sign in to comment.