diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 0572ed77b40ef..3e2c60af6b5ee 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -1219,6 +1219,124 @@ cdef class FlightMetadataWriter(_Weakrefable):
         check_flight_status(self.writer.get().WriteMetadata(deref(buf)))
 
 
+class AsyncioCall:
+    """State for an async RPC using asyncio.
+
+    Must be created while an asyncio event loop is running: the loop is
+    captured so that wakeup(), which may be called from another thread,
+    can wake up the waiting task.
+    """
+
+    def __init__(self) -> None:
+        import asyncio
+
+        # Python waits on the event.  The C++ callback sets the event,
+        # waking up the original task.
+        self._event = asyncio.Event()
+        # The result of the async call.
+        # TODO(lidavidm): how best to handle streams?
+        self._result = None
+        # The error raised by the async call.
+        self._exception = None
+        self._loop = asyncio.get_running_loop()
+
+    async def wait(self) -> object:
+        """Wait for the RPC call to finish.
+
+        Returns the result given to wakeup(), or raises the exception
+        given to it.
+        """
+        await self._event.wait()
+        if self._exception is not None:
+            raise self._exception
+        return self._result
+
+    def wakeup(self, *, exception=None, result=None) -> None:
+        """Finish the RPC call with either an exception or a result."""
+        # Test against None instead of truthiness so that an exception
+        # object that happens to be falsy is still reported.
+        if exception is not None:
+            self._exception = exception
+        else:
+            self._result = result
+        # Set the event from within the loop to avoid a race (asyncio
+        # objects are not necessarily thread-safe)
+        self._loop.call_soon_threadsafe(self._event.set)
+
+
+cdef class AsyncioFlightClient:
+    """
+    A FlightClient with an asyncio-based async interface.
+
+    This interface is EXPERIMENTAL.
+    """
+
+    cdef:
+        FlightClient client
+
+    def __init__(self, FlightClient client not None) -> None:
+        self.client = client
+
+    async def get_flight_info(
+        self,
+        descriptor: FlightDescriptor,
+        *,
+        options: FlightCallOptions = None,
+    ):
+        """Request information about an available flight.
+
+        Parameters
+        ----------
+        descriptor : FlightDescriptor
+            The descriptor of the flight.
+        options : FlightCallOptions, optional
+            RPC-layer options for this call.
+
+        Returns
+        -------
+        FlightInfo
+        """
+        call = AsyncioCall()
+        self._get_flight_info(call, descriptor, options)
+        return await call.wait()
+
+    cdef _get_flight_info(self, call, descriptor, options):
+        """Start the async RPC; ``call`` is woken up when it completes."""
+        cdef:
+            CFlightCallOptions* c_options = \
+                FlightCallOptions.unwrap(options)
+            CFlightDescriptor c_descriptor = \
+                FlightDescriptor.unwrap(descriptor)
+            function[cb_client_async_get_flight_info] callback = \
+                &_client_async_get_flight_info
+
+        with nogil:
+            CAsyncGetFlightInfo(
+                self.client.client.get(),
+                deref(c_options),
+                c_descriptor,
+                call,
+                callback,
+            )
+
+
+cdef void _client_async_get_flight_info(object self, CFlightInfo* info, const CStatus& status):
+    """Bridge the C++ async call with the Python side.
+
+    ``self`` is the AsyncioCall that was handed to CAsyncGetFlightInfo.
+    ``info`` is only dereferenced when ``status`` is OK.
+    """
+    cdef FlightInfo result = FlightInfo.__new__(FlightInfo)
+    call: AsyncioCall = self
+    try:
+        check_status(status)
+    except Exception as e:
+        call.wakeup(exception=e)
+    else:
+        result.info.reset(new CFlightInfo(move(deref(info))))
+        call.wakeup(result=result)
+
+
 cdef class FlightClient(_Weakrefable):
     """A client to a Flight service.
 
@@ -1320,6 +1438,13 @@ cdef class FlightClient(_Weakrefable):
         check_flight_status(CFlightClient.Connect(c_location, c_options
                                                   ).Value(&self.client))
 
+    def as_async(self) -> AsyncioFlightClient:
+        """Make an asyncio-based client that wraps this client.
+
+        This interface is EXPERIMENTAL.
+        """
+        return AsyncioFlightClient(self)
+
     def wait_for_available(self, timeout=5):
         """Block until the server can be contacted.
 
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 624904ed77a69..ebe845f300237 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -496,6 +496,10 @@ ctypedef CStatus cb_client_middleware_start_call(
     const CCallInfo&,
     unique_ptr[CClientMiddleware]*)
 
+# Callback for CAsyncGetFlightInfo: receives the Python context object, the
+# resulting FlightInfo (NULL when the call failed), and the call status.
+ctypedef void cb_client_async_get_flight_info(object, CFlightInfo* info, const CStatus& status)
+
 cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
     cdef char* CPyServerMiddlewareName\
         " arrow::py::flight::kPyServerMiddlewareName"
@@ -604,6 +608,8 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
         shared_ptr[CSchema] schema,
         unique_ptr[CSchemaResult]* out)
 
+    cdef void CAsyncGetFlightInfo" arrow::py::flight::AsyncGetFlightInfo"(CFlightClient*, const CFlightCallOptions&, const CFlightDescriptor&, object, function[cb_client_async_get_flight_info])
+
 
 cdef extern from "<variant>" namespace "std" nogil:
     cdef cppclass CIntStringVariant" std::variant<int, std::string>":
diff --git a/python/pyarrow/src/arrow/python/flight.cc b/python/pyarrow/src/arrow/python/flight.cc
index bf7af27ac726e..3e0f5ffaaf16a 100644
--- a/python/pyarrow/src/arrow/python/flight.cc
+++ b/python/pyarrow/src/arrow/python/flight.cc
@@ -383,6 +383,35 @@ Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
   return arrow::flight::SchemaResult::Make(*schema).Value(out);
 }
 
+void AsyncGetFlightInfo(arrow::flight::FlightClient* client,
+                        const arrow::flight::FlightCallOptions& options,
+                        const arrow::flight::FlightDescriptor& descriptor,
+                        PyObject* context, AsyncGetFlightInfoCallback callback) {
+  // `context` is a borrowed reference and this function may be called without
+  // the GIL.  OwnedRefNoGIL adopts the pointer without incrementing its
+  // refcount, so take a strong reference first (with the GIL held) to keep the
+  // refcount balanced when the callback, and with it py_context, is destroyed.
+  {
+    PyAcquireGIL lock;
+    Py_INCREF(context);
+  }
+  OwnedRefNoGIL py_context(context);
+  auto future = client->GetFlightInfoAsync(options, descriptor);
+  future.AddCallback([callback, py_context = std::move(py_context)](
+                         arrow::Result<arrow::flight::FlightInfo> result) {
+    // Invoke the Python-facing callback through SafeCallIntoPython so that it
+    // runs with the GIL held.
+    std::ignore = SafeCallIntoPython([&] {
+      if (result.ok()) {
+        callback(py_context.obj(), &result.ValueOrDie(), result.status());
+      } else {
+        callback(py_context.obj(), nullptr, result.status());
+      }
+      return Status::OK();
+    });
+  });
+}
+
 }  // namespace flight
 }  // namespace py
 }  // namespace arrow
diff --git a/python/pyarrow/src/arrow/python/flight.h b/python/pyarrow/src/arrow/python/flight.h
index 82d93711e55fb..c8c460783b434 100644
--- a/python/pyarrow/src/arrow/python/flight.h
+++ b/python/pyarrow/src/arrow/python/flight.h
@@ -345,6 +345,23 @@ ARROW_PYFLIGHT_EXPORT
 Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
                           std::unique_ptr<arrow::flight::SchemaResult>* out);
 
+/// \brief Callback invoked when an asynchronous GetFlightInfo call finishes.
+///
+/// Receives the context object given to AsyncGetFlightInfo, the resulting
+/// FlightInfo (nullptr when the call failed), and the status of the call.
+typedef std::function<void(PyObject*, arrow::flight::FlightInfo*,
+                           const Status&)>
+    AsyncGetFlightInfoCallback;
+
+/// \brief Start an asynchronous GetFlightInfo call.
+///
+/// A reference to \a context is held until \a callback has been invoked.
+ARROW_PYFLIGHT_EXPORT
+void AsyncGetFlightInfo(arrow::flight::FlightClient* client,
+                        const arrow::flight::FlightCallOptions& options,
+                        const arrow::flight::FlightDescriptor& descriptor,
+                        PyObject* context, AsyncGetFlightInfoCallback callback);
+
 }  // namespace flight
 }  // namespace py
 }  // namespace arrow
diff --git a/python/pyarrow/tests/test_flight_async.py b/python/pyarrow/tests/test_flight_async.py
new file mode 100644
index 0000000000000..11fd54155fd1a
--- /dev/null
+++ b/python/pyarrow/tests/test_flight_async.py
@@ -0,0 +1,70 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import asyncio
+
+import pytest
+
+import pyarrow
+
+flight = pytest.importorskip("pyarrow.flight")
+pytestmark = pytest.mark.flight
+
+
+class ExampleServer(flight.FlightServerBase):
+    """Flight server whose GetFlightInfo only succeeds for b'simple'."""
+
+    simple_info = flight.FlightInfo(
+        pyarrow.schema([("a", "int32")]),
+        flight.FlightDescriptor.for_command(b"simple"),
+        [], -1, -1,
+    )
+
+    def get_flight_info(self, context, descriptor):
+        if descriptor.command == b"simple":
+            return self.simple_info
+        if descriptor.command == b"unknown":
+            raise NotImplementedError("Unknown command")
+        raise NotImplementedError("Unknown descriptor")
+
+
+@pytest.fixture(scope="module")
+def async_client():
+    """Yield an async client for an ExampleServer shared by the module."""
+    with ExampleServer() as server:
+        with flight.connect(f"grpc://localhost:{server.port}") as client:
+            yield client.as_async()
+
+
+def test_get_flight_info(async_client):
+    """The b'simple' command yields the server's canned FlightInfo."""
+    async def fetch():
+        simple = flight.FlightDescriptor.for_command(b"simple")
+        return await async_client.get_flight_info(simple)
+
+    assert asyncio.run(fetch()) == ExampleServer.simple_info
+
+
+def test_get_flight_info_error(async_client):
+    """The server's NotImplementedError is expected to reach the awaiter."""
+    async def fetch():
+        unknown = flight.FlightDescriptor.for_command(b"unknown")
+        return await async_client.get_flight_info(unknown)
+
+    with pytest.raises(NotImplementedError) as excinfo:
+        asyncio.run(fetch())
+    assert "Unknown command" in repr(excinfo.value)