Skip to content

Commit

Permalink
Refactor to share code between do_put and do_exchange calls
Browse files Browse the repository at this point in the history
Signed-off-by: Praveen Kumar <[email protected]>
  • Loading branch information
bitpacker committed May 6, 2024
1 parent 520ad68 commit d97a935
Showing 1 changed file with 94 additions and 47 deletions.
141 changes: 94 additions & 47 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::task::Poll;
use std::{pin::Pin, task::Poll};

use crate::{
decode::FlightRecordBatchStream,
Expand All @@ -28,10 +28,7 @@ use crate::{
use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
future::ready,
ready,
stream::{self, BoxStream},
FutureExt, Stream, StreamExt, TryStreamExt,
channel::oneshot::{Receiver, Sender}, future::ready, stream::{self, BoxStream}, FutureExt, Stream, StreamExt, TryStreamExt
};
use prost::Message;
use tonic::{metadata::MetadataMap, transport::Channel};
Expand Down Expand Up @@ -364,33 +361,18 @@ impl FlightClient {
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
let (sender, mut receiver) = futures::channel::oneshot::channel();
let (sender, receiver) = futures::channel::oneshot::channel();

// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});
let request = Box::pin(request); // Pin to heap
let request_stream = FallibleRequestStream::new(sender, request);

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_put(request).await?.into_inner();
let response_stream = self.inner.do_put(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});
let response_stream = Box::pin(response_stream);
let error_stream = FallibleResponseStream::new(receiver, response_stream);

// combine the response from the server and any error from the client
Ok(error_stream.boxed())
Expand Down Expand Up @@ -433,33 +415,26 @@ impl FlightClient {
&mut self,
request: S,
) -> Result<FlightRecordBatchStream> {
let (sender, mut receiver) = futures::channel::oneshot::channel();
let (sender, receiver) = futures::channel::oneshot::channel();

// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});
let request = Box::pin(request); // Pin to heap
let request_stream = FallibleRequestStream::new(sender, request);

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_exchange(request).await?.into_inner();
let response_stream = self.inner.do_exchange(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});
// let error_stream = futures::stream::poll_fn(move |cx| {
// if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
// return Poll::Ready(Some(Err(err)));
// }
// let next = ready!(response_stream.poll_next_unpin(cx));
// Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
// });

let response_stream = Box::pin(response_stream);
let error_stream = FallibleResponseStream::new(receiver, response_stream);

// combine the response from the server and any error from the client
Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
Expand Down Expand Up @@ -704,3 +679,75 @@ impl FlightClient {
request
}
}

struct FallibleRequestStream {
sender: Option<Sender<FlightError>>,
request_streams: Pin<Box<dyn Stream<Item = Result<FlightData>> + Send + 'static>>,
}

impl Stream for FallibleRequestStream {
type Item = FlightData;

fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let mut request_streams = pinned.request_streams.as_mut();
match request_streams.poll_next_unpin(cx) {
Poll::Ready(s) => {
match s {
Some(Ok(data)) => Poll::Ready(Some(data)),
Some(Err(e)) => {
let _ = pinned.sender.take().unwrap().send(e);
Poll::Ready(None)
},
None => Poll::Ready(None),
}

},
Poll::Pending => Poll::Pending,
}
}
}

impl FallibleRequestStream {
fn new(sender: Sender<FlightError>, request_streams: Pin<Box<dyn Stream<Item = Result<FlightData>> + Send + 'static>>) -> Self {
Self {
sender: Some(sender),
request_streams
}
}
}

struct FallibleResponseStream<T> {
receiver: Receiver<FlightError>,
response_streams: Pin<Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>>,
}

impl <T> Stream for FallibleResponseStream<T> {
type Item = Result<T>;

fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
let pinned = self.get_mut();
let receiver = &mut pinned.receiver;
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
};

match pinned.response_streams.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(res))) => Poll::Ready(Some(Ok(res))),
Poll::Ready(Some(Err(status))) => Poll::Ready(Some(Err(FlightError::Tonic(status)))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}

impl <T> FallibleResponseStream<T> {
fn new(receiver: Receiver<FlightError>,
response_streams: Pin<Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>>) -> Self {
Self {
receiver,
response_streams,
}
}
}

0 comments on commit d97a935

Please sign in to comment.