Skip to content

Commit

Permalink
GH-37093: [FlightRPC][Python] Add async GetFlightInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Aug 9, 2023
1 parent 9f183fc commit 68f2aa0
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 0 deletions.
92 changes: 92 additions & 0 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,95 @@ cdef class FlightMetadataWriter(_Weakrefable):
check_flight_status(self.writer.get().WriteMetadata(deref(buf)))


class AsyncioCall:
    """State for an async RPC using asyncio.

    One instance tracks a single in-flight call. The coroutine that
    started the call awaits :meth:`wait`; whoever completes the call
    invokes :meth:`wakeup` to publish the outcome and resume the waiter.

    Instances must be created while an event loop is running, because the
    constructor captures the running loop.
    """

    def __init__(self) -> None:
        import asyncio

        # Python waits on the event. The C++ callback sets the event,
        # waking up the original task.
        self._event = asyncio.Event()
        # The result of the async call.
        # TODO(lidavidm): how best to handle streams?
        self._result = None
        # The error raised by the async call.
        self._exception = None
        # The loop the waiter runs on; wakeup() schedules onto it.
        self._loop = asyncio.get_running_loop()

    async def wait(self) -> object:
        """Wait for the RPC call to finish.

        Returns
        -------
        object
            The value passed to ``wakeup(result=...)``.

        Raises
        ------
        BaseException
            Whatever was passed to ``wakeup(exception=...)``.
        """
        await self._event.wait()
        # Compare against None explicitly: an exception object may define
        # __bool__/__len__ and be falsy, and must still be raised.
        if self._exception is not None:
            raise self._exception
        return self._result

    def wakeup(self, *, exception=None, result=None) -> None:
        """Finish the RPC call.

        Parameters
        ----------
        exception : BaseException, optional
            The error to raise from :meth:`wait`. Takes precedence over
            ``result`` when given.
        result : object, optional
            The value to return from :meth:`wait`.
        """
        if exception is not None:
            self._exception = exception
        else:
            self._result = result
        # Set the event from within the loop to avoid a race (asyncio
        # objects are not necessarily thread-safe)
        self._loop.call_soon_threadsafe(self._event.set)


cdef class AsyncioFlightClient:
    """
    An asyncio-based async interface layered over a FlightClient.

    This interface is EXPERIMENTAL.
    """

    cdef:
        FlightClient client

    def __init__(self, FlightClient client) -> None:
        # Calls are issued through the wrapped client.
        self.client = client

    async def get_flight_info(
        self,
        descriptor: FlightDescriptor,
        *,
        options: FlightCallOptions = None,
    ):
        """Request information about an available flight.

        Parameters
        ----------
        descriptor : FlightDescriptor
            The flight to look up.
        options : FlightCallOptions, optional
            Per-call options.

        Returns
        -------
        The value delivered to the call's completion callback; the error
        delivered there is raised instead when the call fails.
        """
        pending = AsyncioCall()
        # Kick off the native call before suspending; completion is
        # reported back through ``pending``.
        self._get_flight_info(pending, descriptor, options)
        outcome = await pending.wait()
        return outcome

    cdef _get_flight_info(self, call, descriptor, options):
        """Start the native async call; completion is reported to ``call``."""
        cdef:
            CFlightCallOptions* native_options = \
                FlightCallOptions.unwrap(options)
            CFlightDescriptor native_descriptor = \
                FlightDescriptor.unwrap(descriptor)
            function[cb_client_async_get_flight_info] on_done = \
                &_client_async_get_flight_info

        with nogil:
            CAsyncGetFlightInfo(
                self.client.client.get(),
                deref(native_options),
                native_descriptor,
                call,
                on_done,
            )


cdef void _client_async_get_flight_info(void* self, CFlightInfo* info, const CStatus& status):
    """Bridge the C++ async call with the Python side.

    ``self`` is the AsyncioCall that was handed to the native call. On a
    failed ``status`` the resulting Python exception is delivered to the
    call; otherwise ``info`` is moved into a new FlightInfo wrapper and
    delivered as the result.
    """
    cdef FlightInfo result
    call: AsyncioCall = <object> self
    try:
        # Use the same status check as the other Flight call sites in this
        # module so the async API reports errors consistently with them.
        check_flight_status(status)
    except Exception as e:
        call.wakeup(exception=e)
    else:
        # Only build the wrapper once we know there is a value to wrap.
        result = FlightInfo.__new__(FlightInfo)
        result.info.reset(new CFlightInfo(move(deref(info))))
        call.wakeup(result=result)


cdef class FlightClient(_Weakrefable):
"""A client to a Flight service.
Expand Down Expand Up @@ -1320,6 +1409,9 @@ cdef class FlightClient(_Weakrefable):
check_flight_status(CFlightClient.Connect(c_location, c_options
).Value(&self.client))

def as_async(self) -> AsyncioFlightClient:
    """Return an asyncio-based wrapper around this client.

    The wrapper holds a reference to this client and issues its calls
    through it.

    Returns
    -------
    AsyncioFlightClient
    """
    return AsyncioFlightClient(self)

def wait_for_available(self, timeout=5):
"""Block until the server can be contacted.
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,8 @@ ctypedef CStatus cb_client_middleware_start_call(
const CCallInfo&,
unique_ptr[CClientMiddleware]*)

# Completion callback for an async GetFlightInfo call. Arguments: the Python
# object supplied when the call was started, a pointer to the resulting
# FlightInfo, and the final status of the call.
ctypedef void cb_client_async_get_flight_info(object, CFlightInfo* info, const CStatus& status)

cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
cdef char* CPyServerMiddlewareName\
" arrow::py::flight::kPyServerMiddlewareName"
Expand Down Expand Up @@ -604,6 +606,8 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
shared_ptr[CSchema] schema,
unique_ptr[CSchemaResult]* out)

# Start an async GetFlightInfo call on the given client. The Python object is
# the context passed back as the first argument of the completion callback.
cdef void CAsyncGetFlightInfo" arrow::py::flight::AsyncGetFlightInfo"(CFlightClient*, const CFlightCallOptions&, const CFlightDescriptor&, object, function[cb_client_async_get_flight_info])


cdef extern from "<variant>" namespace "std" nogil:
cdef cppclass CIntStringVariant" std::variant<int, std::string>":
Expand Down
19 changes: 19 additions & 0 deletions python/pyarrow/src/arrow/python/flight.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,25 @@ Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
return arrow::flight::SchemaResult::Make(*schema).Value(out);
}

// Start an async GetFlightInfo call and report its outcome to `callback`.
//
// `context` is handed back to `callback` as its first argument. On success
// the callback receives a pointer to the FlightInfo (valid only for the
// duration of the callback); on failure it receives nullptr. The status is
// passed in both cases.
void AsyncGetFlightInfo(arrow::flight::FlightClient* client,
                        const arrow::flight::FlightCallOptions& options,
                        const arrow::flight::FlightDescriptor& descriptor,
                        PyObject* context, AsyncGetFlightInfoCallback callback) {
  // `context` is a borrowed reference, while OwnedRefNoGIL adopts the pointer
  // and releases one reference when destroyed. Take a strong reference of our
  // own first so the object outlives the pending call and is not
  // over-released. The GIL is acquired explicitly because callers may invoke
  // this function without holding it.
  {
    PyAcquireGIL lock;
    Py_INCREF(context);
  }
  OwnedRefNoGIL py_context(context);
  auto future = client->GetFlightInfoAsync(options, descriptor);
  future.AddCallback([callback, py_context = std::move(py_context)](
                         arrow::Result<arrow::flight::FlightInfo> result) {
    // The callback touches Python objects, so run it via SafeCallIntoPython.
    std::ignore = SafeCallIntoPython([&] {
      arrow::flight::FlightInfo* info = result.ok() ? &result.ValueOrDie() : nullptr;
      callback(py_context.obj(), info, result.status());
      return Status::OK();
    });
  });
}

} // namespace flight
} // namespace py
} // namespace arrow
10 changes: 10 additions & 0 deletions python/pyarrow/src/arrow/python/flight.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,16 @@ ARROW_PYFLIGHT_EXPORT
Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
std::unique_ptr<arrow::flight::SchemaResult>* out);

/// \brief Completion callback for AsyncGetFlightInfo.
///
/// \param self the Python context object supplied when the call was started
/// \param info the resulting FlightInfo, or nullptr if the call failed
/// \param status the final status of the call
typedef std::function<void(PyObject* self, arrow::flight::FlightInfo* info,
const Status& status)>
AsyncGetFlightInfoCallback;

/// \brief Start an async GetFlightInfo call on the given client.
///
/// \param client the Flight client to issue the call on
/// \param options per-call options
/// \param descriptor the flight to look up
/// \param context Python object passed back to the callback as its first
///   argument
/// \param callback invoked once the call has completed
ARROW_PYFLIGHT_EXPORT
void AsyncGetFlightInfo(arrow::flight::FlightClient* client,
const arrow::flight::FlightCallOptions& options,
const arrow::flight::FlightDescriptor& descriptor,
PyObject* context, AsyncGetFlightInfoCallback callback);

} // namespace flight
} // namespace py
} // namespace arrow
70 changes: 70 additions & 0 deletions python/pyarrow/tests/test_flight_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import asyncio

import pytest

import pyarrow

flight = pytest.importorskip("pyarrow.flight")
pytestmark = pytest.mark.flight


class ExampleServer(flight.FlightServerBase):
    """Minimal Flight server used to exercise the async client.

    It answers the ``b"simple"`` command with a canned FlightInfo and
    rejects everything else with NotImplementedError.
    """

    # Canned response for the b"simple" command.
    simple_info = flight.FlightInfo(
        pyarrow.schema([("a", "int32")]),
        flight.FlightDescriptor.for_command(b"simple"),
        [],
        -1,
        -1,
    )

    def get_flight_info(self, context, descriptor):
        """Return the canned info for b"simple"; raise for anything else."""
        command = descriptor.command
        if command == b"simple":
            return self.simple_info
        if command == b"unknown":
            raise NotImplementedError("Unknown command")
        raise NotImplementedError("Unknown descriptor")


@pytest.fixture(scope="module")
def async_client():
    """Yield an async client connected to a module-scoped ExampleServer."""
    # The server is entered first and the client second, so the client is
    # closed before the server shuts down.
    with ExampleServer() as server, \
            flight.connect(f"grpc://localhost:{server.port}") as client:
        yield client.as_async()


def test_get_flight_info(async_client):
    """A known command resolves to the server's canned FlightInfo."""

    async def fetch_simple():
        request = flight.FlightDescriptor.for_command(b"simple")
        received = await async_client.get_flight_info(request)
        assert received == ExampleServer.simple_info

    asyncio.run(fetch_simple())


def test_get_flight_info_error(async_client):
    """A server-side error surfaces as an exception from the awaited call."""

    async def fetch_unknown():
        request = flight.FlightDescriptor.for_command(b"unknown")
        with pytest.raises(NotImplementedError) as excinfo:
            await async_client.get_flight_info(request)

        # The server's message must survive the round trip.
        assert "Unknown command" in repr(excinfo.value)

    asyncio.run(fetch_unknown())

0 comments on commit 68f2aa0

Please sign in to comment.