Skip to content

Commit

Permalink
Refactor to share code between do_put and do_exchange calls
Browse files Browse the repository at this point in the history
Signed-off-by: Praveen Kumar <[email protected]>
  • Loading branch information
bitpacker committed May 6, 2024
1 parent 520ad68 commit fb8370f
Showing 1 changed file with 100 additions and 43 deletions.
143 changes: 100 additions & 43 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::task::Poll;
use std::{pin::Pin, task::Poll};

use crate::{
decode::FlightRecordBatchStream,
Expand All @@ -28,6 +28,7 @@ use crate::{
use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
channel::oneshot::{Receiver, Sender},
future::ready,
ready,
stream::{self, BoxStream},
Expand Down Expand Up @@ -364,33 +365,18 @@ impl FlightClient {
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
let (sender, mut receiver) = futures::channel::oneshot::channel();
let (sender, receiver) = futures::channel::oneshot::channel();

// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});
let request = Box::pin(request); // Pin to heap
let request_stream = FallibleRequestStream::new(sender, request);

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_put(request).await?.into_inner();
let response_stream = self.inner.do_put(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});
let response_stream = Box::pin(response_stream);
let error_stream = FallibleResponseStream::new(receiver, response_stream);

// combine the response from the server and any error from the client
Ok(error_stream.boxed())
Expand Down Expand Up @@ -433,33 +419,26 @@ impl FlightClient {
&mut self,
request: S,
) -> Result<FlightRecordBatchStream> {
let (sender, mut receiver) = futures::channel::oneshot::channel();
let (sender, receiver) = futures::channel::oneshot::channel();

// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});
let request = Box::pin(request); // Pin to heap
let request_stream = FallibleRequestStream::new(sender, request);

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_exchange(request).await?.into_inner();
let response_stream = self.inner.do_exchange(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});
// let error_stream = futures::stream::poll_fn(move |cx| {
// if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
// return Poll::Ready(Some(Err(err)));
// }
// let next = ready!(response_stream.poll_next_unpin(cx));
// Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
// });

let response_stream = Box::pin(response_stream);
let error_stream = FallibleResponseStream::new(receiver, response_stream);

// combine the response from the server and any error from the client
Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
Expand Down Expand Up @@ -704,3 +683,81 @@ impl FlightClient {
request
}
}

struct FallibleRequestStream {
sender: Option<Sender<FlightError>>,
request_streams: Pin<Box<dyn Stream<Item = Result<FlightData>> + Send + 'static>>,
}

impl Stream for FallibleRequestStream {
type Item = FlightData;

fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let mut request_streams = pinned.request_streams.as_mut();
match ready!(request_streams.poll_next_unpin(cx)) {
Some(Ok(data)) => Poll::Ready(Some(data)),
Some(Err(e)) => {
let _ = pinned.sender.take().unwrap().send(e);
Poll::Ready(None)
}
None => Poll::Ready(None),
}
}
}

impl FallibleRequestStream {
fn new(
sender: Sender<FlightError>,
request_streams: Pin<Box<dyn Stream<Item = Result<FlightData>> + Send + 'static>>,
) -> Self {
Self {
sender: Some(sender),
request_streams,
}
}
}

struct FallibleResponseStream<T> {
receiver: Receiver<FlightError>,
response_streams:
Pin<Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>>,
}

impl<T> Stream for FallibleResponseStream<T> {
type Item = Result<T>;

fn poll_next(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let receiver = &mut pinned.receiver;
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
};

match ready!(pinned.response_streams.poll_next_unpin(cx)) {
Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))),
None => Poll::Ready(None),
}
}
}

impl<T> FallibleResponseStream<T> {
fn new(
receiver: Receiver<FlightError>,
response_streams: Pin<
Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>,
>,
) -> Self {
Self {
receiver,
response_streams,
}
}
}

0 comments on commit fb8370f

Please sign in to comment.