Skip to content

Commit

Permalink
Bind to future instead of dealing with callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou authored and lidavidm committed Aug 18, 2023
1 parent 3e05f9d commit 0bd8c44
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 63 deletions.
39 changes: 15 additions & 24 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,12 @@ cdef class FlightInfo(_Weakrefable):
cdef:
unique_ptr[CFlightInfo] info

@staticmethod
cdef _wrap_unsafe(void* c_info):
# Build a FlightInfo Python wrapper from an untyped pointer.
# "Unsafe": c_info is cast to CFlightInfo* with no type check, so the
# caller must guarantee it really points at a CFlightInfo.
# Allocate the wrapper via __new__, i.e. without going through __init__.
cdef FlightInfo obj = FlightInfo.__new__(FlightInfo)
# Move-construct a new heap CFlightInfo from the pointee and hand it to
# the wrapper's unique_ptr; the object behind c_info is left moved-from.
obj.info.reset(new CFlightInfo(move(deref(<CFlightInfo*> c_info))))
return obj

def __init__(self, Schema schema, FlightDescriptor descriptor, endpoints,
total_records, total_bytes):
"""Create a FlightInfo object from a schema, descriptor, and endpoints.
Expand Down Expand Up @@ -1241,10 +1247,11 @@ class AsyncioCall:
raise self._exception
return self._result

def wakeup(self, *, result=None, exception=None) -> None:
"""Finish the RPC call."""
self._result = result
self._exception = exception
def wakeup(self, result_or_exception) -> None:
# Finish the RPC call: record the outcome, then set self._event on the
# event loop.
# A BaseException instance is stored as the failure; any other object is
# stored as the successful result.
if isinstance(result_or_exception, BaseException):
self._exception = result_or_exception
else:
self._result = result_or_exception
# Set the event from within the loop to avoid a race (asyncio
# objects are not necessarily thread-safe)
self._loop.call_soon_threadsafe(lambda: self._event.set())
Expand Down Expand Up @@ -1279,29 +1286,13 @@ cdef class AsyncioFlightClient:
FlightCallOptions.unwrap(options)
CFlightDescriptor c_descriptor = \
FlightDescriptor.unwrap(descriptor)
function[cb_client_async_get_flight_info] callback = \
&_client_async_get_flight_info
CFuture[CFlightInfo] c_future

with nogil:
CAsyncGetFlightInfo(
self._client.client.get(),
deref(c_options),
c_descriptor,
call,
callback,
)

c_future = self._client.client.get().GetFlightInfoAsync(
deref(c_options), c_descriptor)

cdef void _client_async_get_flight_info(void* self, CFlightInfo* info, const CStatus& status) except *:
"""Bridge the C++ async call with the Python side."""
# `self` is an opaque pointer that is cast back to the AsyncioCall below;
# `info` is only dereferenced when `status` is OK.
cdef:
FlightInfo result = FlightInfo.__new__(FlightInfo)
call: AsyncioCall = <object> self
if status.ok():
# Success: move the C++ FlightInfo into the Python wrapper and deliver it.
result.info.reset(new CFlightInfo(move(deref(info))))
call.wakeup(result=result)
else:
# Failure: deliver the status converted to a Python exception object.
call.wakeup(exception=convert_status(status))
BindFuture(move(c_future), call.wakeup, FlightInfo._wrap_unsafe)


cdef class FlightClient(_Weakrefable):
Expand Down
15 changes: 12 additions & 3 deletions python/pyarrow/error.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class ArrowCancelled(ArrowException):
ArrowIOError = IOError


# This function could be written directly in C++ if we didn't
# define Arrow-specific subclasses (ArrowInvalid etc.)
# check_status() and convert_status() could be written directly in C++
# if we didn't define Arrow-specific subclasses (ArrowInvalid etc.)
cdef int check_status(const CStatus& status) except -1 nogil:
if status.ok():
return 0
Expand All @@ -92,6 +92,12 @@ cdef int check_status(const CStatus& status) except -1 nogil:


cdef object convert_status(const CStatus& status):
if IsPyError(status):
try:
RestorePyError(status)
except BaseException as e:
return e

# We don't use Status::ToString() as it would redundantly include
# the C++ class name.
message = frombytes(status.message(), safe=True)
Expand Down Expand Up @@ -142,11 +148,14 @@ cdef object convert_status(const CStatus& status):
return ArrowException(message)


# This is an API function for C++ PyArrow
# These are API functions for C++ PyArrow
cdef api int pyarrow_internal_check_status(const CStatus& status) \
except -1 nogil:
# Exported (cdef api) entry point: simply forwards to check_status().
# A return value of -1 signals a raised Python exception (except -1).
return check_status(status)

cdef api object pyarrow_internal_convert_status(const CStatus& status):
# Exported (cdef api) entry point: simply forwards to convert_status()
# and returns whatever object it produces.
return convert_status(status)


cdef class StopToken:
cdef void init(self, CStopToken stop_token):
Expand Down
10 changes: 10 additions & 0 deletions python/pyarrow/includes/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,20 @@ cdef extern from "arrow/result.h" namespace "arrow" nogil:
T operator*()


cdef extern from "arrow/util/future.h" namespace "arrow" nogil:
cdef cppclass CFuture "arrow::Future"[T]:
CFuture()


ctypedef object PyWrapper(void*)


cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil:
T GetResultValue[T](CResult[T]) except *
cdef function[F] BindFunction[F](void* unbound, object bound, ...)

void BindFuture[T](CFuture[T], object cb, PyWrapper wrapper)


cdef inline object PyObject_to_object(PyObject* o):
# Cast to "object" increments reference count
Expand Down
6 changes: 2 additions & 4 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
CResult[unique_ptr[CFlightListing]] ListFlights(CFlightCallOptions& options, CCriteria criteria)
CResult[unique_ptr[CFlightInfo]] GetFlightInfo(CFlightCallOptions& options,
CFlightDescriptor& descriptor)
CFuture[CFlightInfo] GetFlightInfoAsync(CFlightCallOptions& options,
CFlightDescriptor& descriptor)
CResult[unique_ptr[CSchemaResult]] GetSchema(CFlightCallOptions& options,
CFlightDescriptor& descriptor)
CResult[unique_ptr[CFlightStreamReader]] DoGet(CFlightCallOptions& options, CTicket& ticket)
Expand Down Expand Up @@ -496,8 +498,6 @@ ctypedef CStatus cb_client_middleware_start_call(
const CCallInfo&,
unique_ptr[CClientMiddleware]*)

ctypedef void cb_client_async_get_flight_info(object, CFlightInfo* info, const CStatus& status)

cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
cdef char* CPyServerMiddlewareName\
" arrow::py::flight::kPyServerMiddlewareName"
Expand Down Expand Up @@ -606,8 +606,6 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
shared_ptr[CSchema] schema,
unique_ptr[CSchemaResult]* out)

cdef void CAsyncGetFlightInfo" arrow::py::flight::AsyncGetFlightInfo"(CFlightClient*, const CFlightCallOptions&, const CFlightDescriptor&, object, function[cb_client_async_get_flight_info])


cdef extern from "<variant>" namespace "std" nogil:
cdef cppclass CIntStringVariant" std::variant<int, std::string>":
Expand Down
2 changes: 1 addition & 1 deletion python/pyarrow/includes/libarrow_python.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ cdef extern from "arrow/python/pyarrow.h" namespace "arrow::py":

cdef extern from "arrow/python/common.h" namespace "arrow::py":
c_bool IsPyError(const CStatus& status)
void RestorePyError(const CStatus& status)
void RestorePyError(const CStatus& status) except *


cdef extern from "arrow/python/inference.h" namespace "arrow::py":
Expand Down
54 changes: 54 additions & 0 deletions python/pyarrow/src/arrow/python/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "arrow/python/pyarrow.h"
#include "arrow/python/visibility.h"
#include "arrow/result.h"
#include "arrow/util/future.h"
#include "arrow/util/macros.h"

namespace arrow {
Expand Down Expand Up @@ -72,6 +73,29 @@ T GetResultValue(Result<T> result) {
}
}

// Turn a Result<T> into a Python object.
// * On success, py_wrapper is invoked with a pointer to the contained value
//   and the PyObject* it produces is returned.
// * On failure -- either the Result itself holds an error, or a Python error
//   is pending after py_wrapper ran -- the corresponding Python exception
//   object is returned instead.
template <typename T, typename PyWrapper = PyObject* (*)(void*)>
PyObject* WrapResult(Result<T> result, PyWrapper py_wrapper) {
  static_assert(std::is_same_v<PyObject*, decltype(py_wrapper(std::declval<T*>()))>,
                "PyWrapper argument to WrapResult should return a PyObject* "
                "when called with a T*");
  if (!result.ok()) {
    // The Result already carries an error: report it as an exception object.
    return internal::convert_status(result.status());
  }

  PyObject* wrapped = py_wrapper(&result.ValueUnsafe());
  // The wrapper may have raised; CheckPyError() reports that as a Status.
  Status wrap_status = CheckPyError();
  if (wrap_status.ok()) {
    return wrapped;
  }

  // Wrapping failed: drop whatever was returned (expected to be null, but be
  // defensive) and report the failure as an exception object.
  Py_XDECREF(wrapped);
  return internal::convert_status(wrap_status);
}

// A RAII-style helper that ensures the GIL is acquired inside a lexical block.
class ARROW_PYTHON_EXPORT PyAcquireGIL {
public:
Expand Down Expand Up @@ -131,6 +155,19 @@ auto SafeCallIntoPython(Function&& func) -> decltype(func()) {
return maybe_status;
}

// Run `func` with the GIL held, preserving any Python error that was already
// pending on entry.
//
// The pending error (if any) is stashed with PyErr_Fetch() before the call and
// reinstated with PyErr_Restore() afterwards. `func` is expected to return
// void; should it return a value, that value is discarded. (The return type
// used to be decltype(func()) with no return statement, which is undefined
// behaviour for any callable that does return a value.)
template <typename Function>
void SafeCallIntoPythonVoid(Function&& func) {
  PyAcquireGIL lock;
  PyObject* exc_type = NULLPTR;
  PyObject* exc_value = NULLPTR;
  PyObject* exc_traceback = NULLPTR;
  PyErr_Fetch(&exc_type, &exc_value, &exc_traceback);
  static_cast<void>(func());
  // Only restore when something was actually fetched, so that an error state
  // left behind by `func` is not wiped by restoring an empty one.
  if (exc_type != NULLPTR) {
    PyErr_Restore(exc_type, exc_value, exc_traceback);
  }
}

// A RAII primitive that DECREFs the underlying PyObject* when it
// goes out of scope.
class ARROW_PYTHON_EXPORT OwnedRef {
Expand Down Expand Up @@ -251,6 +288,23 @@ std::function<OutFn> BindFunction(Return (*unbound)(PyObject*, Args...),
[bound_fn](Args... args) { return bound_fn->Invoke(std::forward<Args>(args)...); };
}

// XXX Put this in arrow/python/async.h to avoid adding more random stuff here?
//
// Arrange for the Python callable `py_cb` to be called once `future` completes.
// * The future's Result<T> is converted with WrapResult(result, py_wrapper),
//   yielding either the wrapped value or a Python exception object.
// * `py_cb` is then called with that single object as its only argument; its
//   return value is discarded.
// * Any Python error pending afterwards is logged as a warning.
// Py_INCREF is applied to `py_cb` on entry, so this must be called in a
// context where touching Python reference counts is allowed.
template <typename T, typename Wrapper = PyObject* (*)(void*)>
void BindFuture(Future<T> future, PyObject* py_cb, Wrapper py_wrapper) {
  // Take our own reference to the callback and let cb_ref carry it into the
  // continuation (presumably released by OwnedRefNoGIL when that is destroyed).
  Py_INCREF(py_cb);
  OwnedRefNoGIL cb_ref(py_cb);

  auto future_cb = [cb_ref = std::move(cb_ref), py_wrapper](Result<T> result) {
    SafeCallIntoPythonVoid([&]() {
      // py_wrapper is a const by-copy capture here, so it is passed by copy.
      OwnedRef py_value_or_exc{WrapResult(std::move(result), py_wrapper)};
      // A null object would be read by PyObject_CallFunctionObjArgs as the
      // argument-list terminator, i.e. the callback would be invoked with no
      // arguments at all. Skip the call in that case and let the pending
      // Python error be reported below.
      if (py_value_or_exc.obj() != NULLPTR) {
        Py_XDECREF(
            PyObject_CallFunctionObjArgs(cb_ref.obj(), py_value_or_exc.obj(), NULLPTR));
      }
      ARROW_WARN_NOT_OK(CheckPyError(), "Internal error in async call");
    });
  };
  future.AddCallback(std::move(future_cb));
}

// A temporary conversion of a Python object to a bytes area.
struct PyBytesView {
const char* bytes;
Expand Down
21 changes: 0 additions & 21 deletions python/pyarrow/src/arrow/python/flight.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,27 +383,6 @@ Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
return arrow::flight::SchemaResult::Make(*schema).Value(out);
}

// Start an asynchronous GetFlightInfo call on `client` and route its outcome
// to `callback`.
//
// `context` is passed back as the callback's first argument. On success the
// callback also receives a pointer to the FlightInfo; on failure that pointer
// is null. The Result's status is passed in both cases. A Python error left
// pending by the callback is logged as a warning.
void AsyncGetFlightInfo(arrow::flight::FlightClient* client,
                        const arrow::flight::FlightCallOptions& options,
                        const arrow::flight::FlightDescriptor& descriptor,
                        PyObject* context, AsyncGetFlightInfoCallback callback) {
  // Hold our own reference to the context object for the continuation.
  Py_INCREF(context);
  OwnedRefNoGIL py_context(context);

  auto future = client->GetFlightInfoAsync(options, descriptor);

  auto on_complete = [callback, py_context = std::move(py_context)](
                         arrow::Result<arrow::flight::FlightInfo> result) mutable {
    auto status = SafeCallIntoPython([&] {
      // Only expose the value when there is one; otherwise pass a null pointer.
      arrow::flight::FlightInfo* info = result.ok() ? &result.ValueOrDie() : nullptr;
      callback(py_context.obj(), info, result.status());
      return CheckPyError();
    });
    ARROW_WARN_NOT_OK(status, "Internal error in async get_flight_info");
  };
  future.AddCallback(std::move(on_complete));
}

} // namespace flight
} // namespace py
} // namespace arrow
10 changes: 0 additions & 10 deletions python/pyarrow/src/arrow/python/flight.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,16 +345,6 @@ ARROW_PYFLIGHT_EXPORT
Status CreateSchemaResult(const std::shared_ptr<arrow::Schema>& schema,
std::unique_ptr<arrow::flight::SchemaResult>* out);

typedef std::function<void(PyObject* self, arrow::flight::FlightInfo* info,
const Status& status)>
AsyncGetFlightInfoCallback;

ARROW_PYFLIGHT_EXPORT
void AsyncGetFlightInfo(arrow::flight::FlightClient* client,
const arrow::flight::FlightCallOptions& options,
const arrow::flight::FlightDescriptor& descriptor,
PyObject* context, AsyncGetFlightInfoCallback callback);

} // namespace flight
} // namespace py
} // namespace arrow
6 changes: 6 additions & 0 deletions python/pyarrow/src/arrow/python/pyarrow.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "arrow/table.h"
#include "arrow/tensor.h"
#include "arrow/type.h"
#include "arrow/util/logging.h"

#include "arrow/python/common.h"
#include "arrow/python/datetime.h"
Expand Down Expand Up @@ -89,6 +90,11 @@ namespace internal {

// Forwards to the Cython-exported pyarrow_internal_check_status() and returns
// its result unchanged.
int check_status(const Status& status) { return ::pyarrow_internal_check_status(status); }

// Convert an error status to a Python exception object by forwarding to the
// Cython-exported pyarrow_internal_convert_status().
PyObject* convert_status(const Status& status) {
// Precondition: callers must pass a non-OK status (enforced via DCHECK,
// which is typically only active in debug builds).
DCHECK(!status.ok());
return ::pyarrow_internal_convert_status(status);
}

} // namespace internal
} // namespace py
} // namespace arrow
5 changes: 5 additions & 0 deletions python/pyarrow/src/arrow/python/pyarrow.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,13 @@ DECLARE_WRAP_FUNCTIONS(table, Table)

namespace internal {

// If status is ok, return 0.
// If status is not ok, set Python error indicator and return -1.
ARROW_PYTHON_EXPORT int check_status(const Status& status);

// Convert status to a Python exception object. Status must not be ok.
ARROW_PYTHON_EXPORT PyObject* convert_status(const Status& status);

} // namespace internal
} // namespace py
} // namespace arrow
Expand Down

0 comments on commit 0bd8c44

Please sign in to comment.