Skip to content

Commit

Permalink
Refactor to share code between do_put and do_exchange calls
Browse files Browse the repository at this point in the history
Signed-off-by: Praveen Kumar <[email protected]>
  • Loading branch information
bitpacker committed May 7, 2024
1 parent 520ad68 commit 7729b20
Showing 1 changed file with 110 additions and 44 deletions.
154 changes: 110 additions & 44 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::task::Poll;
use std::{pin::Pin, task::Poll};

use crate::{
decode::FlightRecordBatchStream,
Expand All @@ -28,6 +28,7 @@ use crate::{
use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
channel::oneshot::{Receiver, Sender},
future::ready,
ready,
stream::{self, BoxStream},
Expand Down Expand Up @@ -364,33 +365,18 @@ impl FlightClient {
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
let (sender, mut receiver) = futures::channel::oneshot::channel();
let (sender, receiver) = futures::channel::oneshot::channel();

// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});
let request = Box::pin(request); // Pin to heap
let request_stream = FallibleRequestStream::new(sender, request);

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_put(request).await?.into_inner();
let response_stream = self.inner.do_put(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});
let response_stream = Box::pin(response_stream);
let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);

// combine the response from the server and any error from the client
Ok(error_stream.boxed())
Expand Down Expand Up @@ -433,33 +419,17 @@ impl FlightClient {
&mut self,
request: S,
) -> Result<FlightRecordBatchStream> {
let (sender, mut receiver) = futures::channel::oneshot::channel();
let (sender, receiver) = futures::channel::oneshot::channel();

let request = Box::pin(request);
// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});
let request_stream = FallibleRequestStream::new(sender, request);

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_exchange(request).await?.into_inner();
let response_stream = self.inner.do_exchange(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});
let response_stream = Box::pin(response_stream);
let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);

// combine the response from the server and any error from the client
Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
Expand Down Expand Up @@ -704,3 +674,99 @@ impl FlightClient {
request
}
}

/// Wrapper around fallible stream such that when
/// it encounters an error it uses the oneshot sender to
/// notify the error and stop any further streaming. See `do_put` or
/// `do_exchange` for it's uses.
struct FallibleRequestStream<T, E> {
/// sender to notify error
sender: Option<Sender<E>>,
/// fallible stream
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
}

impl<T, E> FallibleRequestStream<T, E> {
fn new(
sender: Sender<E>,
fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
) -> Self {
Self {
sender: Some(sender),
fallible_stream,
}
}
}

impl<T, E> Stream for FallibleRequestStream<T, E> {
type Item = T;

fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let mut request_streams = pinned.fallible_stream.as_mut();
match ready!(request_streams.poll_next_unpin(cx)) {
Some(Ok(data)) => Poll::Ready(Some(data)),
Some(Err(e)) => {
// unwrap() here is safe, ownership of sender will
// be moved only once as this stream will not be polled
// again
let _ = pinned.sender.take().unwrap().send(e);
Poll::Ready(None)
}
None => Poll::Ready(None),
}
}
}

/// Wrapper for a tonic response stream that can produce a tonic
/// error. This is tied to a oneshot receiver which can be notified
/// of other errors. When it receives an error through receiver
/// end, it prioritises that error to be sent back. See `do_put` or
/// `do_exchange` for it's uses
struct FallibleTonicResponseStream<T> {
/// Receiver for FlightError
receiver: Receiver<FlightError>,
/// Tonic response stream
response_stream:
Pin<Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>>,
}

impl<T> FallibleTonicResponseStream<T> {
fn new(
receiver: Receiver<FlightError>,
response_stream: Pin<
Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>,
>,
) -> Self {
Self {
receiver,
response_stream,
}
}
}

impl<T> Stream for FallibleTonicResponseStream<T> {
type Item = Result<T>;

fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let receiver = &mut pinned.receiver;
// Prioritise sending the error that's been notified over
// polling the response_stream
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
};

match ready!(pinned.response_stream.poll_next_unpin(cx)) {
Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))),
None => Poll::Ready(None),
}
}
}

0 comments on commit 7729b20

Please sign in to comment.