From fb8370f1fb26f2e8938cf39cbf979a7e692eb588 Mon Sep 17 00:00:00 2001 From: Praveen Kumar Date: Mon, 6 May 2024 16:58:18 +0100 Subject: [PATCH] Refactor to share code between do_put and do_exchange calls Signed-off-by: Praveen Kumar --- arrow-flight/src/client.rs | 143 ++++++++++++++++++++++++++----------- 1 file changed, 100 insertions(+), 43 deletions(-) diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index a7e15fe24cc5..f534b838e13e 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::task::Poll; +use std::{pin::Pin, task::Poll}; use crate::{ decode::FlightRecordBatchStream, @@ -28,6 +28,7 @@ use crate::{ use arrow_schema::Schema; use bytes::Bytes; use futures::{ + channel::oneshot::{Receiver, Sender}, future::ready, ready, stream::{self, BoxStream}, @@ -364,33 +365,18 @@ impl FlightClient { &mut self, request: S, ) -> Result>> { - let (sender, mut receiver) = futures::channel::oneshot::channel(); + let (sender, receiver) = futures::channel::oneshot::channel(); // Intercepts client errors and sends them to the oneshot channel above - let mut request = Box::pin(request); // Pin to heap - let mut sender = Some(sender); // Wrap into Option so can be taken - let request_stream = futures::stream::poll_fn(move |cx| { - Poll::Ready(match ready!(request.poll_next_unpin(cx)) { - Some(Ok(data)) => Some(data), - Some(Err(e)) => { - let _ = sender.take().unwrap().send(e); - None - } - None => None, - }) - }); + let request = Box::pin(request); // Pin to heap + let request_stream = FallibleRequestStream::new(sender, request); let request = self.make_request(request_stream); - let mut response_stream = self.inner.do_put(request).await?.into_inner(); + let response_stream = self.inner.do_put(request).await?.into_inner(); // Forwards errors from the error oneshot with priority over responses from server - let error_stream = futures::stream::poll_fn(move |cx| { - if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { - return Poll::Ready(Some(Err(err))); - } - let next = ready!(response_stream.poll_next_unpin(cx)); - Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) - }); + let response_stream = Box::pin(response_stream); + let error_stream = FallibleResponseStream::new(receiver, response_stream); // combine the response from the server and any error from the client Ok(error_stream.boxed()) @@ -433,33 +419,26 @@ impl FlightClient { &mut self, request: S, ) -> Result { - let (sender, mut receiver) = futures::channel::oneshot::channel(); + let (sender, receiver) = futures::channel::oneshot::channel(); // Intercepts client errors and sends them to the oneshot channel above - let mut request = Box::pin(request); // Pin to heap - let mut sender = Some(sender); // Wrap into Option so can be taken - let request_stream = futures::stream::poll_fn(move |cx| { - Poll::Ready(match ready!(request.poll_next_unpin(cx)) { - Some(Ok(data)) => Some(data), - Some(Err(e)) => { - let _ = sender.take().unwrap().send(e); - None - } - None => None, - }) - }); + let request = Box::pin(request); // Pin to heap + let request_stream = FallibleRequestStream::new(sender, request); let request = self.make_request(request_stream); - let mut response_stream = self.inner.do_exchange(request).await?.into_inner(); + let response_stream = self.inner.do_exchange(request).await?.into_inner(); // Forwards errors from the error oneshot with priority over responses from server - let error_stream = futures::stream::poll_fn(move |cx| { - if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { - return Poll::Ready(Some(Err(err))); - } - let next = ready!(response_stream.poll_next_unpin(cx)); - Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) - }); + // let error_stream = futures::stream::poll_fn(move |cx| { + // if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + // return Poll::Ready(Some(Err(err))); + // } + // let next = ready!(response_stream.poll_next_unpin(cx)); + // Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) + // }); + + let response_stream = Box::pin(response_stream); + let error_stream = FallibleResponseStream::new(receiver, response_stream); // combine the response from the server and any error from the client Ok(FlightRecordBatchStream::new_from_flight_data(error_stream)) @@ -704,3 +683,81 @@ impl FlightClient { request } } + +struct FallibleRequestStream { + sender: Option>, + request_streams: Pin> + Send + 'static>>, +} + +impl Stream for FallibleRequestStream { + type Item = FlightData; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let pinned = self.get_mut(); + let mut request_streams = pinned.request_streams.as_mut(); + match ready!(request_streams.poll_next_unpin(cx)) { + Some(Ok(data)) => Poll::Ready(Some(data)), + Some(Err(e)) => { + let _ = pinned.sender.take().unwrap().send(e); + Poll::Ready(None) + } + None => Poll::Ready(None), + } + } +} + +impl FallibleRequestStream { + fn new( + sender: Sender, + request_streams: Pin> + Send + 'static>>, + ) -> Self { + Self { + sender: Some(sender), + request_streams, + } + } +} + +struct FallibleResponseStream { + receiver: Receiver, + response_streams: + Pin> + Send + 'static>>, +} + +impl Stream for FallibleResponseStream { + type Item = Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let pinned = self.get_mut(); + let receiver = &mut pinned.receiver; + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + }; + + match ready!(pinned.response_streams.poll_next_unpin(cx)) { + Some(Ok(res)) => Poll::Ready(Some(Ok(res))), + Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))), + None => Poll::Ready(None), + } + } +} + +impl FallibleResponseStream { + fn new( + receiver: Receiver, + response_streams: Pin< + Box> + Send + 'static>, + >, + ) -> Self { + Self { + receiver, + response_streams, + } + } +}