Skip to content

Commit eb2d00b

Browse files
Fallible stream for arrow-flight do_exchange call (#3462) (#5698)
Signed-off-by: Praveen Kumar <[email protected]> Co-authored-by: Praveen Kumar <[email protected]>
1 parent 6348dc3 commit eb2d00b

File tree

2 files changed

+124
-15
lines changed

2 files changed

+124
-15
lines changed

arrow-flight/src/client.rs

+30-12
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,7 @@ impl FlightClient {
417417
///
418418
/// // encode the batch as a stream of `FlightData`
419419
/// let flight_data_stream = FlightDataEncoderBuilder::new()
420-
/// .build(futures::stream::iter(vec![Ok(batch)]))
421-
/// // data encoder return Results, but do_exchange requires FlightData
422-
/// .map(|batch|batch.unwrap());
420+
/// .build(futures::stream::iter(vec![Ok(batch)]));
423421
///
424422
/// // send the stream and get the results as `RecordBatches`
425423
/// let response: Vec<RecordBatch> = client
@@ -431,20 +429,40 @@ impl FlightClient {
431429
/// .expect("error calling do_exchange");
432430
/// # }
433431
/// ```
434-
pub async fn do_exchange<S: Stream<Item = FlightData> + Send + 'static>(
432+
pub async fn do_exchange<S: Stream<Item = Result<FlightData>> + Send + 'static>(
435433
&mut self,
436434
request: S,
437435
) -> Result<FlightRecordBatchStream> {
438-
let request = self.make_request(request);
436+
let (sender, mut receiver) = futures::channel::oneshot::channel();
439437

440-
let response = self
441-
.inner
442-
.do_exchange(request)
443-
.await?
444-
.into_inner()
445-
.map_err(FlightError::Tonic);
438+
// Intercepts client errors and sends them to the oneshot channel above
439+
let mut request = Box::pin(request); // Pin to heap
440+
let mut sender = Some(sender); // Wrap into Option so can be taken
441+
let request_stream = futures::stream::poll_fn(move |cx| {
442+
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
443+
Some(Ok(data)) => Some(data),
444+
Some(Err(e)) => {
445+
let _ = sender.take().unwrap().send(e);
446+
None
447+
}
448+
None => None,
449+
})
450+
});
451+
452+
let request = self.make_request(request_stream);
453+
let mut response_stream = self.inner.do_exchange(request).await?.into_inner();
454+
455+
// Forwards errors from the error oneshot with priority over responses from server
456+
let error_stream = futures::stream::poll_fn(move |cx| {
457+
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
458+
return Poll::Ready(Some(Err(err)));
459+
}
460+
let next = ready!(response_stream.poll_next_unpin(cx));
461+
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
462+
});
446463

447-
Ok(FlightRecordBatchStream::new_from_flight_data(response))
464+
// combine the response from the server and any error from the client
465+
Ok(FlightRecordBatchStream::new_from_flight_data(error_stream))
448466
}
449467

450468
/// Make a `ListFlights` call to the server with the provided

arrow-flight/tests/client.rs

+94-3
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ async fn test_do_exchange() {
493493
.set_do_exchange_response(output_flight_data.clone().into_iter().map(Ok).collect());
494494

495495
let response_stream = client
496-
.do_exchange(futures::stream::iter(input_flight_data.clone()))
496+
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
497497
.await
498498
.expect("error making request");
499499

@@ -528,7 +528,7 @@ async fn test_do_exchange_error() {
528528
let input_flight_data = test_flight_data().await;
529529

530530
let response = client
531-
.do_exchange(futures::stream::iter(input_flight_data.clone()))
531+
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
532532
.await;
533533
let response = match response {
534534
Ok(_) => panic!("unexpected success"),
@@ -572,7 +572,7 @@ async fn test_do_exchange_error_stream() {
572572
test_server.set_do_exchange_response(response);
573573

574574
let response_stream = client
575-
.do_exchange(futures::stream::iter(input_flight_data.clone()))
575+
.do_exchange(futures::stream::iter(input_flight_data.clone()).map(Ok))
576576
.await
577577
.expect("error making request");
578578

@@ -593,6 +593,97 @@ async fn test_do_exchange_error_stream() {
593593
.await;
594594
}
595595

596+
#[tokio::test]
597+
async fn test_do_exchange_error_stream_client() {
598+
do_test(|test_server, mut client| async move {
599+
client.add_header("foo-header", "bar-header-value").unwrap();
600+
601+
let e = Status::invalid_argument("bad arg: client");
602+
603+
// input stream to client sends good FlightData followed by an error
604+
let input_flight_data = test_flight_data().await;
605+
let input_stream = futures::stream::iter(input_flight_data.clone())
606+
.map(Ok)
607+
.chain(futures::stream::iter(vec![Err(FlightError::from(
608+
e.clone(),
609+
))]));
610+
611+
let output_flight_data = FlightData::new()
612+
.with_descriptor(FlightDescriptor::new_cmd("Sample command"))
613+
.with_data_body("body".as_bytes())
614+
.with_data_header("header".as_bytes())
615+
.with_app_metadata("metadata".as_bytes());
616+
617+
// server responds with one good message
618+
let response = vec![Ok(output_flight_data)];
619+
test_server.set_do_exchange_response(response);
620+
621+
let response_stream = client
622+
.do_exchange(input_stream)
623+
.await
624+
.expect("error making request");
625+
626+
let response: Result<Vec<_>, _> = response_stream.try_collect().await;
627+
let response = match response {
628+
Ok(_) => panic!("unexpected success"),
629+
Err(e) => e,
630+
};
631+
632+
// expect to the error made from the client
633+
expect_status(response, e);
634+
// server still got the request messages until the client sent the error
635+
assert_eq!(
636+
test_server.take_do_exchange_request(),
637+
Some(input_flight_data)
638+
);
639+
ensure_metadata(&client, &test_server);
640+
})
641+
.await;
642+
}
643+
644+
#[tokio::test]
645+
async fn test_do_exchange_error_client_and_server() {
646+
do_test(|test_server, mut client| async move {
647+
client.add_header("foo-header", "bar-header-value").unwrap();
648+
649+
let e_client = Status::invalid_argument("bad arg: client");
650+
let e_server = Status::invalid_argument("bad arg: server");
651+
652+
// input stream to client sends good FlightData followed by an error
653+
let input_flight_data = test_flight_data().await;
654+
let input_stream = futures::stream::iter(input_flight_data.clone())
655+
.map(Ok)
656+
.chain(futures::stream::iter(vec![Err(FlightError::from(
657+
e_client.clone(),
658+
))]));
659+
660+
// server responds with an error (e.g. because it got truncated data)
661+
let response = vec![Err(e_server)];
662+
test_server.set_do_exchange_response(response);
663+
664+
let response_stream = client
665+
.do_exchange(input_stream)
666+
.await
667+
.expect("error making request");
668+
669+
let response: Result<Vec<_>, _> = response_stream.try_collect().await;
670+
let response = match response {
671+
Ok(_) => panic!("unexpected success"),
672+
Err(e) => e,
673+
};
674+
675+
// expect to the error made from the client (not the server)
676+
expect_status(response, e_client);
677+
// server still got the request messages until the client sent the error
678+
assert_eq!(
679+
test_server.take_do_exchange_request(),
680+
Some(input_flight_data)
681+
);
682+
ensure_metadata(&client, &test_server);
683+
})
684+
.await;
685+
}
686+
596687
#[tokio::test]
597688
async fn test_get_schema() {
598689
do_test(|test_server, mut client| async move {

0 commit comments

Comments
 (0)