From 9b8b42b522a80c0f1cca9f80d209b48b53b7b849 Mon Sep 17 00:00:00 2001
From: Praveen Kumar <praveen@bit2byte.net>
Date: Mon, 6 May 2024 16:58:18 +0100
Subject: [PATCH] Refactor to share code between do_put and do_exchange calls

Signed-off-by: Praveen Kumar <praveen@bit2byte.net>
---
 arrow-flight/src/client.rs | 154 ++++++++++++++++++++++++++-----------
 1 file changed, 110 insertions(+), 44 deletions(-)

diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs
index a7e15fe24cc5..f5f28f683a9d 100644
--- a/arrow-flight/src/client.rs
+++ b/arrow-flight/src/client.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::task::Poll;
+use std::{pin::Pin, task::Poll};
 
 use crate::{
     decode::FlightRecordBatchStream,
@@ -28,6 +28,7 @@ use crate::{
 use arrow_schema::Schema;
 use bytes::Bytes;
 use futures::{
+    channel::oneshot::{Receiver, Sender},
     future::ready,
     ready,
     stream::{self, BoxStream},
@@ -364,33 +365,18 @@ impl FlightClient {
         &mut self,
         request: S,
     ) -> Result<BoxStream<'static, Result<PutResult>>> {
-        let (sender, mut receiver) = futures::channel::oneshot::channel();
+        let (sender, receiver) = futures::channel::oneshot::channel();
 
         // Intercepts client errors and sends them to the oneshot channel above
-        let mut request = Box::pin(request); // Pin to heap
-        let mut sender = Some(sender); // Wrap into Option so can be taken
-        let request_stream = futures::stream::poll_fn(move |cx| {
-            Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
-                Some(Ok(data)) => Some(data),
-                Some(Err(e)) => {
-                    let _ = sender.take().unwrap().send(e);
-                    None
-                }
-                None => None,
-            })
-        });
+        let request = Box::pin(request); // Pin to heap
+        let request_stream = FallibleRequestStream::new(sender, request);
 
         let request = self.make_request(request_stream);
-        let mut response_stream = self.inner.do_put(request).await?.into_inner();
+        let response_stream = self.inner.do_put(request).await?.into_inner();
 
         // Forwards errors from the error oneshot with priority over responses from server
-        let error_stream = futures::stream::poll_fn(move |cx| {
-            if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
-                return Poll::Ready(Some(Err(err)));
-            }
-            let next = ready!(response_stream.poll_next_unpin(cx));
-            Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
-        });
+        let response_stream = Box::pin(response_stream);
+        let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);
 
         // combine the response from the server and any error from the client
         Ok(error_stream.boxed())
@@ -433,33 +419,17 @@ impl FlightClient {
         &mut self,
         request: S,
     ) -> Result<FlightRecordBatchStream> {
-        let (sender, mut receiver) = futures::channel::oneshot::channel();
+        let (sender, receiver) = futures::channel::oneshot::channel();
 
+        let request = Box::pin(request);
         // Intercepts client errors and sends them to the oneshot channel above
-        let mut request = Box::pin(request); // Pin to heap
-        let mut sender = Some(sender); // Wrap into Option so can be taken
-        let request_stream = futures::stream::poll_fn(move |cx| {
-            Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
-                Some(Ok(data)) => Some(data),
-                Some(Err(e)) => {
-                    let _ = sender.take().unwrap().send(e);
-                    None
-                }
-                None => None,
-            })
-        });
+        let request_stream = FallibleRequestStream::new(sender, request);
 
         let request = self.make_request(request_stream);
-        let mut response_stream = self.inner.do_exchange(request).await?.into_inner();
+        let response_stream = self.inner.do_exchange(request).await?.into_inner();
 
-        // Forwards errors from the error oneshot with priority over responses from server
-        let error_stream = futures::stream::poll_fn(move |cx| {
-            if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
-                return Poll::Ready(Some(Err(err)));
-            }
-            let next = ready!(response_stream.poll_next_unpin(cx));
-            Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
-        });
+        let response_stream = Box::pin(response_stream);
+        let error_stream = FallibleTonicResponseStream::new(receiver, response_stream);
 
         // combine the response from the server and any error from the client
         Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
@@ -704,3 +674,99 @@ impl FlightClient {
         request
     }
 }
+
+/// Wrapper around fallible stream such that when
+/// it encounters an error it uses the oneshot sender to
+/// notify the error and stop any further streaming. See `do_put` or
+/// `do_exchange` for it's uses.
+struct FallibleRequestStream<T, E> {
+    /// sender to notify error
+    sender: Option<Sender<E>>,
+    /// fallible stream
+    fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
+}
+
+impl<T, E> FallibleRequestStream<T, E> {
+    fn new(
+        sender: Sender<E>,
+        fallible_stream: Pin<Box<dyn Stream<Item = std::result::Result<T, E>> + Send + 'static>>,
+    ) -> Self {
+        Self {
+            sender: Some(sender),
+            fallible_stream,
+        }
+    }
+}
+
+impl<T, E> Stream for FallibleRequestStream<T, E> {
+    type Item = T;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        let pinned = self.get_mut();
+        let mut request_streams = pinned.fallible_stream.as_mut();
+        match ready!(request_streams.poll_next_unpin(cx)) {
+            Some(Ok(data)) => Poll::Ready(Some(data)),
+            Some(Err(e)) => {
+                // unwrap() here is safe, ownership of sender will
+                // be moved only once as this stream will not be polled
+                // again
+                let _ = pinned.sender.take().unwrap().send(e);
+                Poll::Ready(None)
+            }
+            None => Poll::Ready(None),
+        }
+    }
+}
+
+/// Wrapper for a tonic response stream that can produce a tonic
+/// error. This is tied to a oneshot receiver which can be notified
+/// of other errors. When it receives an error through receiver
+/// end, it prioritises that error to be sent back. See `do_put` or
+/// `do_exchange` for it's uses
+struct FallibleTonicResponseStream<T> {
+    /// Receiver for FlightError
+    receiver: Receiver<FlightError>,
+    /// Tonic response stream
+    response_stream:
+        Pin<Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>>,
+}
+
+impl<T> FallibleTonicResponseStream<T> {
+    fn new(
+        receiver: Receiver<FlightError>,
+        response_stream: Pin<
+            Box<dyn Stream<Item = std::result::Result<T, tonic::Status>> + Send + 'static>,
+        >,
+    ) -> Self {
+        Self {
+            receiver,
+            response_stream,
+        }
+    }
+}
+
+impl<T> Stream for FallibleTonicResponseStream<T> {
+    type Item = Result<T>;
+
+    fn poll_next(
+        self: Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let pinned = self.get_mut();
+        let receiver = &mut pinned.receiver;
+        // Prioritise sending the error that's been notified over
+        // polling the response_stream
+        if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
+            return Poll::Ready(Some(Err(err)));
+        };
+
+        match ready!(pinned.response_stream.poll_next_unpin(cx)) {
+            Some(Ok(res)) => Poll::Ready(Some(Ok(res))),
+            Some(Err(status)) => Poll::Ready(Some(Err(FlightError::Tonic(status)))),
+            None => Poll::Ready(None),
+        }
+    }
+}