diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index b2abfb0c17b2..cdc09fecaede 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -34,7 +34,7 @@ use futures::{ FutureExt, Stream, StreamExt, TryStreamExt, }; use prost::Message; -use tonic::{metadata::MetadataMap, transport::Channel}; +use tonic::{metadata::MetadataMap, transport::Channel, Status}; use crate::error::{FlightError, Result}; @@ -417,9 +417,7 @@ impl FlightClient { /// /// // encode the batch as a stream of `FlightData` /// let flight_data_stream = FlightDataEncoderBuilder::new() - /// .build(futures::stream::iter(vec![Ok(batch)])) - /// // data encoder return Results, but do_exchange requires FlightData - /// .map(|batch|batch.unwrap()); + /// .build(futures::stream::iter(vec![Ok(batch)])); /// /// // send the stream and get the results as `RecordBatches` /// let response: Vec = client @@ -431,20 +429,40 @@ impl FlightClient { /// .expect("error calling do_exchange"); /// # } /// ``` - pub async fn do_exchange + Send + 'static>( + pub async fn do_exchange> + Send + 'static>( &mut self, request: S, ) -> Result { - let request = self.make_request(request); + let (sender, mut receiver) = futures::channel::oneshot::channel(); - let response = self - .inner - .do_exchange(request) - .await? - .into_inner() - .map_err(FlightError::Tonic); + // Intercepts client errors and sends them to the oneshot channel above + let mut request = Box::pin(request); // Pin to heap + let mut sender = Some(sender); // Wrap into Option so can be taken + let request_stream = futures::stream::poll_fn(move |cx| { + Poll::Ready(match ready!(request.poll_next_unpin(cx)) { + Some(Ok(data)) => Some(data), + Some(Err(e)) => { + let _ = sender.take().unwrap().send(e); + None + } + None => None, + }) + }); + + let request = self.make_request(request_stream); + let mut response_stream = self.inner.do_exchange(request).await?.into_inner(); + + // Forwards errors from the error oneshot with priority over responses from server + let error_stream = futures::stream::poll_fn(move |cx| -> Poll>> { + if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) { + return Poll::Ready(Some(Err(err))); + } + let next = ready!(response_stream.poll_next_unpin(cx)); + Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic))) + }); - Ok(FlightRecordBatchStream::new_from_flight_data(response)) + // combine the response from the server and any error from the client + Ok(FlightRecordBatchStream::new_from_flight_data(error_stream)) } /// Make a `ListFlights` call to the server with the provided diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 47565334cb63..ddfb7bf8fc6c 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -493,7 +493,7 @@ async fn test_do_exchange() { .set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect()); let response_stream = client - .do_exchange(futures::stream::iter(input_flight_data.clone())) + .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok)) .await .expect("error making request"); @@ -528,7 +528,7 @@ async fn test_do_exchange_error() { let input_flight_data = test_flight_data().await; let response = client - .do_exchange(futures::stream::iter(input_flight_data.clone())) + .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok)) .await; let response = match response { Ok(_) => panic!("unexpected success"), @@ -572,7 +572,7 @@ async fn test_do_exchange_error_stream() { test_server.set_do_exchange_response(response); let response_stream = client - .do_exchange(futures::stream::iter(input_flight_data.clone())) + .do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok)) .await .expect("error making request"); @@ -593,6 +593,92 @@ async fn test_do_exchange_error_stream() { .await; } +#[tokio::test] +async fn test_do_exchange_error_stream_client() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e = Status::invalid_argument("bad arg: client"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + .chain(futures::stream::iter(vec![Err(FlightError::from( + e.clone(), + ))])); + + let output_flight_data = FlightData::new() + .with_descriptor(FlightDescriptor::new_cmd("Sample command")) + .with_data_body("body".as_bytes()) + .with_data_header("header".as_bytes()) + .with_app_metadata("metadata".as_bytes()); + + // server responds with one good message + let response = vec![Ok(output_flight_data)]; + test_server.set_do_exchange_response(response); + + let response_stream = client + .do_exchange(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client + expect_status(response, e); + // server still got the request messages until the client sent the error + assert_eq!(test_server.take_do_exchange_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; + +} + +#[tokio::test] +async fn test_do_exchange_error_client_and_server() { + do_test(|test_server, mut client| async move { + client.add_header("foo-header", "bar-header-value").unwrap(); + + let e_client = Status::invalid_argument("bad arg: client"); + let e_server = Status::invalid_argument("bad arg: server"); + + // input stream to client sends good FlightData followed by an error + let input_flight_data = test_flight_data().await; + let input_stream = futures::stream::iter(input_flight_data.clone()) + .map(Ok) + .chain(futures::stream::iter(vec![Err(FlightError::from( + e_client.clone(), + ))])); + + // server responds with an error (e.g. because it got truncated data) + let response = vec![Err(e_server)]; + test_server.set_do_exchange_response(response); + + let response_stream = client + .do_exchange(input_stream) + .await + .expect("error making request"); + + let response: Result, _> = response_stream.try_collect().await; + let response = match response { + Ok(_) => panic!("unexpected success"), + Err(e) => e, + }; + + // expect to the error made from the client (not the server) + expect_status(response, e_client); + // server still got the request messages until the client sent the error + assert_eq!(test_server.take_do_exchange_request(), Some(input_flight_data)); + ensure_metadata(&client, &test_server); + }) + .await; +} + #[tokio::test] async fn test_get_schema() { do_test(|test_server, mut client| async move {