diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index a7e15fe24cc5..8177a187cd90 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::task::Poll; +use std::{pin::Pin, task::Poll}; use crate::{ decode::FlightRecordBatchStream, @@ -28,10 +28,7 @@ use crate::{ use arrow_schema::Schema; use bytes::Bytes; use futures::{ - future::ready, - ready, - stream::{self, BoxStream}, - FutureExt, Stream, StreamExt, TryStreamExt, + channel::oneshot::{Receiver, Sender}, future::ready, stream::{self, BoxStream}, FutureExt, Stream, StreamExt, TryStreamExt }; use prost::Message; use tonic::{metadata::MetadataMap, transport::Channel}; @@ -364,33 +361,18 @@ impl FlightClient { &mut self, request: S, ) -> Result>> { - let (sender, mut receiver) = futures::channel::oneshot::channel(); + let (sender, receiver) = futures::channel::oneshot::channel(); // Intercepts client errors and sends them to the oneshot channel above - let mut request = Box::pin(request); // Pin to heap - let mut sender = Some(sender); // Wrap into Option so can be taken - let request_stream = futures::stream::poll_fn(move |cx| { - Poll::Ready(match ready!(request.poll_next_unpin(cx)) { - Some(Ok(data)) => Some(data), - Some(Err(e)) => { - let _ = sender.take().unwrap().send(e); - None - } - None => None, - }) - }); + let request = Box::pin(request); // Pin to heap + let request_stream = FallibleRequestStream::new(sender, request); let request = self.make_request(request_stream); - let mut response_stream = self.inner.do_put(request).await?.into_inner(); + let response_stream = self.inner.do_put(request).await?.into_inner(); // Forwards errors from the error oneshot with priority over responses from server - let error_stream = futures::stream::poll_fn(move |cx| { - if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { - return Poll::Ready(Some(Err(err))); - } - let next = ready!(response_stream.poll_next_unpin(cx)); - Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) - }); + let response_stream = Box::pin(response_stream); + let error_stream = FallibleResponseStream::new(receiver, response_stream); // combine the response from the server and any error from the client Ok(error_stream.boxed()) @@ -433,33 +415,26 @@ impl FlightClient { &mut self, request: S, ) -> Result { - let (sender, mut receiver) = futures::channel::oneshot::channel(); + let (sender, receiver) = futures::channel::oneshot::channel(); // Intercepts client errors and sends them to the oneshot channel above - let mut request = Box::pin(request); // Pin to heap - let mut sender = Some(sender); // Wrap into Option so can be taken - let request_stream = futures::stream::poll_fn(move |cx| { - Poll::Ready(match ready!(request.poll_next_unpin(cx)) { - Some(Ok(data)) => Some(data), - Some(Err(e)) => { - let _ = sender.take().unwrap().send(e); - None - } - None => None, - }) - }); + let request = Box::pin(request); // Pin to heap + let request_stream = FallibleRequestStream::new(sender, request); let request = self.make_request(request_stream); - let mut response_stream = self.inner.do_exchange(request).await?.into_inner(); + let response_stream = self.inner.do_exchange(request).await?.into_inner(); // Forwards errors from the error oneshot with priority over responses from server - let error_stream = futures::stream::poll_fn(move |cx| { - if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { - return Poll::Ready(Some(Err(err))); - } - let next = ready!(response_stream.poll_next_unpin(cx)); - Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) - }); + // let error_stream = futures::stream::poll_fn(move |cx| { + // if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + // return Poll::Ready(Some(Err(err))); + // } + // let next = ready!(response_stream.poll_next_unpin(cx)); + // Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) + // }); + + let response_stream = Box::pin(response_stream); + let error_stream = FallibleResponseStream::new(receiver, response_stream); // combine the response from the server and any error from the client Ok(FlightRecordBatchStream::new_from_flight_data(error_stream)) @@ -704,3 +679,75 @@ impl FlightClient { request } } + +struct FallibleRequestStream { + sender: Option>, + request_streams: Pin> + Send + 'static>>, +} + +impl Stream for FallibleRequestStream { + type Item = FlightData; + + fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll> { + let pinned = self.get_mut(); + let mut request_streams = pinned.request_streams.as_mut(); + match request_streams.poll_next_unpin(cx) { + Poll::Ready(s) => { + match s { + Some(Ok(data)) => Poll::Ready(Some(data)), + Some(Err(e)) => { + let _ = pinned.sender.take().unwrap().send(e); + Poll::Ready(None) + }, + None => Poll::Ready(None), + } + + }, + Poll::Pending => Poll::Pending, + } + } +} + +impl FallibleRequestStream { + fn new(sender: Sender, request_streams: Pin> + Send + 'static>>) -> Self { + Self { + sender: Some(sender), + request_streams + } + } +} + +struct FallibleResponseStream { + receiver: Receiver, + response_streams: Pin> + Send + 'static>>, +} + +impl Stream for FallibleResponseStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll> { + let pinned = self.get_mut(); + let receiver = &mut pinned.receiver; + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + }; + + match pinned.response_streams.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(res))) => Poll::Ready(Some(Ok(res))), + Poll::Ready(Some(Err(status))) => Poll::Ready(Some(Err(FlightError::Tonic(status)))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl FallibleResponseStream { + fn new(receiver: Receiver, + response_streams: Pin> + Send + 'static>>) -> Self { + Self { + receiver, + response_streams, + } + } +} +