Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RecordBatch <--> FlightData encode/decode + tests #3391

Merged
merged 18 commits into from
Dec 31, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions arrow-flight/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ license = "Apache-2.0"
[dependencies]
arrow-array = { version = "30.0.0", path = "../arrow-array" }
arrow-buffer = { version = "30.0.0", path = "../arrow-buffer" }
# Cast is needed to work around https://github.com/apache/arrow-rs/issues/3389
arrow-cast = { version = "30.0.0", path = "../arrow-cast" }
arrow-ipc = { version = "30.0.0", path = "../arrow-ipc" }
arrow-schema = { version = "30.0.0", path = "../arrow-schema" }
base64 = { version = "0.20", default-features = false, features = ["std"] }
Expand Down
325 changes: 16 additions & 309 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@
// under the License.

use crate::{
flight_service_client::FlightServiceClient, utils::flight_data_to_arrow_batch,
FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient,
FlightDescriptor, FlightInfo, HandshakeRequest, Ticket,
};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::Schema;
use bytes::Bytes;
use futures::{future::ready, ready, stream, StreamExt};
use std::{collections::HashMap, convert::TryFrom, pin::Pin, sync::Arc, task::Poll};
use tonic::{metadata::MetadataMap, transport::Channel, Streaming};
use futures::{future::ready, stream, StreamExt, TryStreamExt};
use tonic::{metadata::MetadataMap, transport::Channel};

use crate::error::{FlightError, Result};

Expand Down Expand Up @@ -161,7 +158,7 @@ impl FlightClient {

/// Make a `DoGet` call to the server with the provided ticket,
/// returning a [`FlightRecordBatchStream`] for reading
/// [`RecordBatch`]es.
/// [`RecordBatch`](arrow_array::RecordBatch)es.
///
/// # Example:
/// ```no_run
Expand Down Expand Up @@ -197,10 +194,17 @@ impl FlightClient {
pub async fn do_get(&mut self, ticket: Ticket) -> Result<FlightRecordBatchStream> {
let request = self.make_request(ticket);

let response = self.inner.do_get(request).await?.into_inner();

let flight_data_stream = FlightDataStream::new(response);
Ok(FlightRecordBatchStream::new(flight_data_stream))
let response_stream = self
.inner
.do_get(request)
.await?
.into_inner()
// convert to FlightError
.map_err(|e| e.into());

Ok(FlightRecordBatchStream::new_from_flight_data(
response_stream,
))
}

/// Make a `GetFlightInfo` call to the server with the provided
Expand Down Expand Up @@ -268,300 +272,3 @@ impl FlightClient {
request
}
}

/// A stream of [`RecordBatch`]es from an Arrow Flight server.
///
/// To access the lower level Flight messages directly, consider
/// calling [`Self::into_inner`] and using the [`FlightDataStream`]
/// directly.
#[derive(Debug)]
pub struct FlightRecordBatchStream {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is moved / renamed / tested in decode.rs

inner: FlightDataStream,
got_schema: bool,
}

impl FlightRecordBatchStream {
    /// Wraps `inner` in a new record batch stream.
    ///
    /// No schema message has been observed at construction time, so
    /// [`Self::got_schema`] initially reports `false`.
    pub fn new(inner: FlightDataStream) -> Self {
        let got_schema = false;
        Self { inner, got_schema }
    }

    /// Reports whether a message defining the schema has been received so far.
    pub fn got_schema(&self) -> bool {
        let Self { got_schema, .. } = self;
        *got_schema
    }

    /// Consumes this stream and hands back the wrapped [`FlightDataStream`].
    pub fn into_inner(self) -> FlightDataStream {
        let Self { inner, .. } = self;
        inner
    }
}
impl futures::Stream for FlightRecordBatchStream {
    type Item = Result<RecordBatch>;

    /// Yields the next [`RecordBatch`] from the wrapped stream, or `None`
    /// once the wrapped stream is exhausted.
    ///
    /// Schema messages and payload-free messages are consumed internally and
    /// never surfaced; a second schema message is reported as a protocol error.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Result<RecordBatch>>> {
        loop {
            // Pull the next decoded message, bailing out on end-of-stream
            // or on an error from the inner stream.
            let decoded = match ready!(self.inner.poll_next_unpin(cx)) {
                Some(Ok(decoded)) => decoded,
                Some(Err(e)) => return Poll::Ready(Some(Err(e))),
                None => return Poll::Ready(None),
            };

            match decoded.payload {
                DecodedPayload::RecordBatch(batch) => {
                    return Poll::Ready(Some(Ok(batch)));
                }
                DecodedPayload::Schema(_) => {
                    if self.got_schema {
                        let err = FlightError::protocol(
                            "Unexpectedly saw multiple Schema messages in FlightData stream",
                        );
                        return Poll::Ready(Some(Err(err)));
                    }
                    // First schema: remember it and go around for the next message.
                    self.got_schema = true;
                }
                DecodedPayload::None => {
                    // Nothing to emit for this message; go around again.
                }
            }
        }
    }
}

/// Wrapper around a stream of [`FlightData`] that handles the details
/// of decoding low level Flight messages into [`Schema`] and
/// [`RecordBatch`]es, including details such as dictionaries.
///
/// # Protocol Details
///
/// The client handles flight messages as follows:
///
/// - **None:** This message has no effect. This is useful to
/// transmit metadata without any actual payload.
///
/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and
/// the decoded schema is returned.
///
/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing
/// dictionary for the same column will be overwritten. This
/// message is NOT visible.
///
/// - **Record Batch:** Record batch is created based on the current
/// schema and dictionaries. This fails if no schema was transmitted
/// yet.
///
/// All other message types (at the time of writing: e.g. tensor and
/// sparse tensor) lead to an error.
///
/// # Example use cases
///
/// 1. Using this low level stream it is possible to receive a stream
/// of RecordBatches in FlightData that have different schemas by
/// handling multiple schema messages separately.
#[derive(Debug)]
pub struct FlightDataStream {
/// Underlying data stream
response: Streaming<FlightData>,
/// Decoding state; `None` until the first Schema message has been decoded
state: Option<FlightStreamState>,
/// Has the end of the inner stream been seen? Once set, polling
/// returns `None` without touching `response` again.
done: bool,
}

impl FlightDataStream {
/// Create a new wrapper around the stream of FlightData
pub fn new(response: Streaming<FlightData>) -> Self {
Self {
state: None,
response,
done: false,
}
}

/// Extracts flight data from the next message, updating decoding
/// state as necessary.
fn extract_message(&mut self, data: FlightData) -> Result<Option<DecodedFlightData>> {
use arrow_ipc::MessageHeader;
let message = arrow_ipc::root_as_message(&data.data_header[..]).map_err(|e| {
FlightError::DecodeError(format!("Error decoding root message: {e}"))
})?;

match message.header_type() {
MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))),
MessageHeader::Schema => {
let schema = Schema::try_from(&data).map_err(|e| {
FlightError::DecodeError(format!("Error decoding schema: {e}"))
})?;

let schema = Arc::new(schema);
let dictionaries_by_field = HashMap::new();

self.state = Some(FlightStreamState {
schema: Arc::clone(&schema),
dictionaries_by_field,
});
Ok(Some(DecodedFlightData::new_schema(data, schema)))
}
MessageHeader::DictionaryBatch => {
let state = if let Some(state) = self.state.as_mut() {
state
} else {
return Err(FlightError::protocol(
"Received DictionaryBatch prior to Schema",
));
};

let buffer: arrow_buffer::Buffer = data.data_body.into();
let dictionary_batch =
message.header_as_dictionary_batch().ok_or_else(|| {
FlightError::protocol(
"Could not get dictionary batch from DictionaryBatch message",
)
})?;

arrow_ipc::reader::read_dictionary(
&buffer,
dictionary_batch,
&state.schema,
&mut state.dictionaries_by_field,
&message.version(),
)
.map_err(|e| {
FlightError::DecodeError(format!(
"Error decoding ipc dictionary: {e}"
))
})?;

// Updated internal state, but no decoded message
Ok(None)
}
MessageHeader::RecordBatch => {
let state = if let Some(state) = self.state.as_ref() {
state
} else {
return Err(FlightError::protocol(
"Received RecordBatch prior to Schema",
));
};

let batch = flight_data_to_arrow_batch(
&data,
Arc::clone(&state.schema),
&state.dictionaries_by_field,
)
.map_err(|e| {
FlightError::DecodeError(format!(
"Error decoding ipc RecordBatch: {e}"
))
})?;

Ok(Some(DecodedFlightData::new_record_batch(data, batch)))
}
other => {
let name = other.variant_name().unwrap_or("UNKNOWN");
Err(FlightError::protocol(format!("Unexpected message: {name}")))
}
}
}
}

impl futures::Stream for FlightDataStream {
    type Item = Result<DecodedFlightData>;

    /// Decodes and yields the next [`FlightData`] message from the server,
    /// or `None` when no further results are available.
    ///
    /// Messages that only update decoding state produce no item; the
    /// underlying stream is polled again until something can be yielded.
    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        // Once the inner stream has ended, never poll it again.
        if self.done {
            return Poll::Ready(None);
        }

        loop {
            let data = match ready!(self.response.poll_next_unpin(cx)) {
                None => {
                    // Inner stream is exhausted; remember that for later polls.
                    self.done = true;
                    return Poll::Ready(None);
                }
                Some(Err(status)) => {
                    return Poll::Ready(Some(Err(FlightError::Tonic(status))));
                }
                Some(Ok(data)) => data,
            };

            // `Ok(None)` becomes `None` here, meaning another input message is needed.
            if let Some(item) = self.extract_message(data).transpose() {
                return Poll::Ready(Some(item));
            }
        }
    }
}

/// Tracks the state needed to reconstruct [`RecordBatch`]es from a
/// streaming flight response.
///
/// A fresh value is created each time a Schema message is decoded.
#[derive(Debug)]
struct FlightStreamState {
/// Schema taken from the most recently decoded Schema message
schema: Arc<Schema>,
/// Dictionaries read from DictionaryBatch messages since the last schema.
/// NOTE(review): the `i64` key is presumably the dictionary id used by
/// `arrow_ipc::reader::read_dictionary` — confirm against that API.
dictionaries_by_field: HashMap<i64, ArrayRef>,
}

/// FlightData and the decoded payload (Schema, RecordBatch), if any
#[derive(Debug)]
pub struct DecodedFlightData {
/// The original, undecoded Flight message
pub inner: FlightData,
/// What `inner` decoded to (may be [`DecodedPayload::None`])
pub payload: DecodedPayload,
}

impl DecodedFlightData {
    /// Pairs `inner` with an empty payload ([`DecodedPayload::None`]).
    pub fn new_none(inner: FlightData) -> Self {
        let payload = DecodedPayload::None;
        Self { inner, payload }
    }

    /// Pairs `inner` with the `schema` decoded from it.
    pub fn new_schema(inner: FlightData, schema: Arc<Schema>) -> Self {
        let payload = DecodedPayload::Schema(schema);
        Self { inner, payload }
    }

    /// Pairs `inner` with the record `batch` decoded from it.
    pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self {
        let payload = DecodedPayload::RecordBatch(batch);
        Self { inner, payload }
    }

    /// Borrows the `app_metadata` bytes of the inner flight data.
    pub fn app_metadata(&self) -> &[u8] {
        let metadata: &[u8] = &self.inner.app_metadata;
        metadata
    }
}

/// The result of decoding [`FlightData`]
///
/// Dictionary batches have no variant here: they only update decoder
/// state and are never surfaced as a payload.
#[derive(Debug)]
pub enum DecodedPayload {
/// None (no data was sent in the corresponding FlightData)
None,

/// A decoded Schema message
Schema(Arc<Schema>),

/// A decoded Record batch.
RecordBatch(RecordBatch),
}
Loading