From da3fca6878ec4a202a5a3b573ce9be9462bcc20a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 24 Dec 2022 07:00:30 -0600 Subject: [PATCH 01/16] Implement `RecordBatch` <--> `FlightData` encode/decode + tests --- arrow-flight/Cargo.toml | 2 + arrow-flight/src/client.rs | 322 +----------------- arrow-flight/src/decode.rs | 396 ++++++++++++++++++++++ arrow-flight/src/encode.rs | 503 ++++++++++++++++++++++++++++ arrow-flight/src/error.rs | 19 ++ arrow-flight/src/lib.rs | 10 +- arrow-flight/tests/client.rs | 92 ++++- arrow-flight/tests/common/server.rs | 46 ++- arrow-flight/tests/encode_decode.rs | 283 ++++++++++++++++ arrow-ipc/src/writer.rs | 5 +- 10 files changed, 1360 insertions(+), 318 deletions(-) create mode 100644 arrow-flight/src/decode.rs create mode 100644 arrow-flight/src/encode.rs create mode 100644 arrow-flight/tests/encode_decode.rs diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 80710d1fac4f..c892a8bcdf9e 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -29,6 +29,8 @@ license = "Apache-2.0" [dependencies] arrow-array = { version = "29.0.0", path = "../arrow-array" } arrow-buffer = { version = "29.0.0", path = "../arrow-buffer" } +# TODO is this needed?? +arrow-cast = { version = "29.0.0", path = "../arrow-cast" } arrow-ipc = { version = "29.0.0", path = "../arrow-ipc" } arrow-schema = { version = "29.0.0", path = "../arrow-schema" } base64 = { version = "0.20", default-features = false, features = ["std"] } diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index 0e75ac7c0c7f..31a912535652 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -16,15 +16,13 @@ // under the License. use crate::{ - flight_service_client::FlightServiceClient, utils::flight_data_to_arrow_batch, - FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, Ticket, + decode::{FlightRecordBatchStream}, + flight_service_client::FlightServiceClient, + FlightDescriptor, FlightInfo, HandshakeRequest, Ticket, }; -use arrow_array::{ArrayRef, RecordBatch}; -use arrow_schema::Schema; use bytes::Bytes; -use futures::{future::ready, ready, stream, StreamExt}; -use std::{collections::HashMap, convert::TryFrom, pin::Pin, sync::Arc, task::Poll}; -use tonic::{metadata::MetadataMap, transport::Channel, Streaming}; +use futures::{future::ready, stream, StreamExt, TryStreamExt}; +use tonic::{metadata::MetadataMap, transport::Channel}; use crate::error::{FlightError, Result}; @@ -161,7 +159,7 @@ impl FlightClient { /// Make a `DoGet` call to the server with the provided ticket, /// returning a [`FlightRecordBatchStream`] for reading - /// [`RecordBatch`]es. + /// [`RecordBatch`](arrow_array::RecordBatch)es. /// /// # Example: /// ```no_run @@ -197,10 +195,15 @@ impl FlightClient { pub async fn do_get(&mut self, ticket: Ticket) -> Result { let request = self.make_request(ticket); - let response = self.inner.do_get(request).await?.into_inner(); + let response_stream = self + .inner + .do_get(request) + .await? + .into_inner() + // convert to FlightError + .map_err(|e| e.into()); - let flight_data_stream = FlightDataStream::new(response); - Ok(FlightRecordBatchStream::new(flight_data_stream)) + Ok(FlightRecordBatchStream::new_from_flight_data(response_stream)) } /// Make a `GetFlightInfo` call to the server with the provided @@ -268,300 +271,3 @@ impl FlightClient { request } } - -/// A stream of [`RecordBatch`]es from from an Arrow Flight server. -/// -/// To access the lower level Flight messages directly, consider -/// calling [`Self::into_inner`] and using the [`FlightDataStream`] -/// directly. -#[derive(Debug)] -pub struct FlightRecordBatchStream { - inner: FlightDataStream, - got_schema: bool, -} - -impl FlightRecordBatchStream { - pub fn new(inner: FlightDataStream) -> Self { - Self { - inner, - got_schema: false, - } - } - - /// Has a message defining the schema been received yet? - pub fn got_schema(&self) -> bool { - self.got_schema - } - - /// Consume self and return the wrapped [`FlightDataStream`] - pub fn into_inner(self) -> FlightDataStream { - self.inner - } -} -impl futures::Stream for FlightRecordBatchStream { - type Item = Result; - - /// Returns the next [`RecordBatch`] available in this stream, or `None` if - /// there are no further results available. - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll>> { - loop { - let res = ready!(self.inner.poll_next_unpin(cx)); - match res { - // Inner exhausted - None => { - return Poll::Ready(None); - } - Some(Err(e)) => { - return Poll::Ready(Some(Err(e))); - } - // translate data - Some(Ok(data)) => match data.payload { - DecodedPayload::Schema(_) if self.got_schema => { - return Poll::Ready(Some(Err(FlightError::protocol( - "Unexpectedly saw multiple Schema messages in FlightData stream", - )))); - } - DecodedPayload::Schema(_) => { - self.got_schema = true; - // Need next message, poll inner again - } - DecodedPayload::RecordBatch(batch) => { - return Poll::Ready(Some(Ok(batch))); - } - DecodedPayload::None => { - // Need next message - } - }, - } - } - } -} - -/// Wrapper around a stream of [`FlightData`] that handles the details -/// of decoding low level Flight messages into [`Schema`] and -/// [`RecordBatch`]es, including details such as dictionaries. -/// -/// # Protocol Details -/// -/// The client handles flight messages as followes: -/// -/// - **None:** This message has no effect. This is useful to -/// transmit metadata without any actual payload. -/// -/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and -/// the decoded schema is returned. -/// -/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing -/// dictionary for the same column will be overwritten. This -/// message is NOT visible. -/// -/// - **Record Batch:** Record batch is created based on the current -/// schema and dictionaries. This fails if no schema was transmitted -/// yet. -/// -/// All other message types (at the time of writing: e.g. tensor and -/// sparse tensor) lead to an error. -/// -/// Example usecases -/// -/// 1. Using this low level stream it is possible to receive a steam -/// of RecordBatches in FlightData that have different schemas by -/// handling multiple schema messages separately. -#[derive(Debug)] -pub struct FlightDataStream { - /// Underlying data stream - response: Streaming, - /// Decoding state - state: Option, - /// seen the end of the inner stream? - done: bool, -} - -impl FlightDataStream { - /// Create a new wrapper around the stream of FlightData - pub fn new(response: Streaming) -> Self { - Self { - state: None, - response, - done: false, - } - } - - /// Extracts flight data from the next message, updating decoding - /// state as necessary. - fn extract_message(&mut self, data: FlightData) -> Result> { - use arrow_ipc::MessageHeader; - let message = arrow_ipc::root_as_message(&data.data_header[..]).map_err(|e| { - FlightError::DecodeError(format!("Error decoding root message: {e}")) - })?; - - match message.header_type() { - MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))), - MessageHeader::Schema => { - let schema = Schema::try_from(&data).map_err(|e| { - FlightError::DecodeError(format!("Error decoding schema: {e}")) - })?; - - let schema = Arc::new(schema); - let dictionaries_by_field = HashMap::new(); - - self.state = Some(FlightStreamState { - schema: Arc::clone(&schema), - dictionaries_by_field, - }); - Ok(Some(DecodedFlightData::new_schema(data, schema))) - } - MessageHeader::DictionaryBatch => { - let state = if let Some(state) = self.state.as_mut() { - state - } else { - return Err(FlightError::protocol( - "Received DictionaryBatch prior to Schema", - )); - }; - - let buffer: arrow_buffer::Buffer = data.data_body.into(); - let dictionary_batch = - message.header_as_dictionary_batch().ok_or_else(|| { - FlightError::protocol( - "Could not get dictionary batch from DictionaryBatch message", - ) - })?; - - arrow_ipc::reader::read_dictionary( - &buffer, - dictionary_batch, - &state.schema, - &mut state.dictionaries_by_field, - &message.version(), - ) - .map_err(|e| { - FlightError::DecodeError(format!( - "Error decoding ipc dictionary: {e}" - )) - })?; - - // Updated internal state, but no decoded message - Ok(None) - } - MessageHeader::RecordBatch => { - let state = if let Some(state) = self.state.as_ref() { - state - } else { - return Err(FlightError::protocol( - "Received RecordBatch prior to Schema", - )); - }; - - let batch = flight_data_to_arrow_batch( - &data, - Arc::clone(&state.schema), - &state.dictionaries_by_field, - ) - .map_err(|e| { - FlightError::DecodeError(format!( - "Error decoding ipc RecordBatch: {e}" - )) - })?; - - Ok(Some(DecodedFlightData::new_record_batch(data, batch))) - } - other => { - let name = other.variant_name().unwrap_or("UNKNOWN"); - Err(FlightError::protocol(format!("Unexpected message: {name}"))) - } - } - } -} - -impl futures::Stream for FlightDataStream { - type Item = Result; - /// Returns the result of decoding the next [`FlightData`] message - /// from the server, or `None` if there are no further results - /// available. - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - if self.done { - return Poll::Ready(None); - } - loop { - let res = ready!(self.response.poll_next_unpin(cx)); - - return Poll::Ready(match res { - None => { - self.done = true; - None // inner is exhausted - } - Some(data) => Some(match data { - Err(e) => Err(FlightError::Tonic(e)), - Ok(data) => match self.extract_message(data) { - Ok(Some(extracted)) => Ok(extracted), - Ok(None) => continue, // Need next input message - Err(e) => Err(e), - }, - }), - }); - } - } -} - -/// tracks the state needed to reconstruct [`RecordBatch`]es from a -/// streaming flight response. -#[derive(Debug)] -struct FlightStreamState { - schema: Arc, - dictionaries_by_field: HashMap, -} - -/// FlightData and the decoded payload (Schema, RecordBatch), if any -#[derive(Debug)] -pub struct DecodedFlightData { - pub inner: FlightData, - pub payload: DecodedPayload, -} - -impl DecodedFlightData { - pub fn new_none(inner: FlightData) -> Self { - Self { - inner, - payload: DecodedPayload::None, - } - } - - pub fn new_schema(inner: FlightData, schema: Arc) -> Self { - Self { - inner, - payload: DecodedPayload::Schema(schema), - } - } - - pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self { - Self { - inner, - payload: DecodedPayload::RecordBatch(batch), - } - } - - /// return the metadata field of the inner flight data - pub fn app_metadata(&self) -> &[u8] { - &self.inner.app_metadata - } -} - -/// The result of decoding [`FlightData`] -#[derive(Debug)] -pub enum DecodedPayload { - /// None (no data was sent in the corresponding FlightData) - None, - - /// A decoded Schema message - Schema(Arc), - - /// A decoded Record batch. - RecordBatch(RecordBatch), -} diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs new file mode 100644 index 000000000000..0b5529c253f4 --- /dev/null +++ b/arrow-flight/src/decode.rs @@ -0,0 +1,396 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{utils::flight_data_to_arrow_batch, FlightData}; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_schema::Schema; +use bytes::Bytes; +use futures::{ready, stream::BoxStream, Stream, StreamExt}; +use std::{ + collections::HashMap, convert::TryFrom, fmt::Debug, pin::Pin, sync::Arc, task::Poll, +}; + +use crate::error::{FlightError, Result}; + +/// Decodes a [Stream] of [`FlightData`] back into +/// [`RecordBatch`]es. This can be used to decode the response from an +/// Arrow Flight server +/// +/// # Note +/// To access the lower level Flight messages (e.g. to access +/// [`FlightData::app_metadata`]), you can call [`Self::into_inner`] +/// and use the [`FlightDataDecoder`] directly. +/// +/// # Example: +/// ```no_run +/// # async fn f() -> Result<(), arrow_flight::error::FlightError>{ +/// # use bytes::Bytes; +/// // make a do_get request +/// use arrow_flight::{ +/// error::Result, +/// decode::FlightRecordBatchStream, +/// Ticket, +/// flight_service_client::FlightServiceClient +/// }; +/// use tonic::transport::Channel; +/// use futures::stream::{StreamExt, TryStreamExt}; +/// +/// let client: FlightServiceClient = // make client.. +/// # unimplemented!(); +/// +/// let request = tonic::Request::new( +/// Ticket { ticket: Bytes::new() } +/// ); +/// +/// // Get a stream of FlightData; +/// let flight_data_stream = client +/// .do_get(request) +/// .await? +/// .into_inner(); +/// +/// // Decode stream of FlightData to RecordBatches +/// let record_batch_stream = FlightRecordBatchStream::new_from_flight_data( +/// // convert tonic::Status to FlightError +/// flight_data_stream.map_err(|e| e.into()) +/// ); +/// +/// // Read back RecordBatches +/// while let Some(batch) = record_batch_stream.next().await { +/// match batch { +/// Ok(batch) => { /* process batch */ }, +/// Err(e) => { /* handle error */ }, +/// }; +/// } +/// +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug)] +pub struct FlightRecordBatchStream { + inner: FlightDataDecoder, + got_schema: bool, +} + +impl FlightRecordBatchStream { + /// Create a new [`FlightRecordBatchStream`] from a decoded stream + pub fn new(inner: FlightDataDecoder) -> Self { + Self { + inner, + got_schema: false, + } + } + + /// Create a new [`FlightRecordBatchStream`] from a stream of [`FlightData`] + pub fn new_from_flight_data(inner: S) -> Self + where + S: Stream> + Send + 'static, + { + Self { + inner: FlightDataDecoder::new(inner), + got_schema: false, + } + } + + /// Has a message defining the schema been received yet? + pub fn got_schema(&self) -> bool { + self.got_schema + } + + /// Consume self and return the wrapped [`FlightDataDecoder`] + pub fn into_inner(self) -> FlightDataDecoder { + self.inner + } +} +impl futures::Stream for FlightRecordBatchStream { + type Item = Result; + + /// Returns the next [`RecordBatch`] available in this stream, or `None` if + /// there are no further results available. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll>> { + loop { + let res = ready!(self.inner.poll_next_unpin(cx)); + match res { + // Inner exhausted + None => { + return Poll::Ready(None); + } + Some(Err(e)) => { + return Poll::Ready(Some(Err(e))); + } + // translate data + Some(Ok(data)) => match data.payload { + DecodedPayload::Schema(_) if self.got_schema => { + return Poll::Ready(Some(Err(FlightError::protocol( + "Unexpectedly saw multiple Schema messages in FlightData stream", + )))); + } + DecodedPayload::Schema(_) => { + self.got_schema = true; + // Need next message, poll inner again + } + DecodedPayload::RecordBatch(batch) => { + return Poll::Ready(Some(Ok(batch))); + } + DecodedPayload::None => { + // Need next message + } + }, + } + } + } +} + +/// Wrapper around a stream of [`FlightData`] that handles the details +/// of decoding low level Flight messages into [`Schema`] and +/// [`RecordBatch`]es, including details such as dictionaries. +/// +/// # Protocol Details +/// +/// The client handles flight messages as followes: +/// +/// - **None:** This message has no effect. This is useful to +/// transmit metadata without any actual payload. +/// +/// - **Schema:** The schema is (re-)set. Dictionaries are cleared and +/// the decoded schema is returned. +/// +/// - **Dictionary Batch:** A new dictionary for a given column is registered. An existing +/// dictionary for the same column will be overwritten. This +/// message is NOT visible. +/// +/// - **Record Batch:** Record batch is created based on the current +/// schema and dictionaries. This fails if no schema was transmitted +/// yet. +/// +/// All other message types (at the time of writing: e.g. tensor and +/// sparse tensor) lead to an error. +/// +/// Example usecases +/// +/// 1. Using this low level stream it is possible to receive a steam +/// of RecordBatches in FlightData that have different schemas by +/// handling multiple schema messages separately. +pub struct FlightDataDecoder { + /// Underlying data stream + response: BoxStream<'static, Result>, + /// Decoding state + state: Option, + /// seen the end of the inner stream? + done: bool, +} + +impl Debug for FlightDataDecoder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FlightDataDecoder") + .field("response", &"") + .field("state", &self.state) + .field("done", &self.done) + .finish() + } +} + +impl FlightDataDecoder { + /// Create a new wrapper around the stream of [`FlightData`] + pub fn new(response: S) -> Self + where + S: Stream> + Send + 'static, + { + Self { + state: None, + response: response.boxed(), + done: false, + } + } + + /// Extracts flight data from the next message, updating decoding + /// state as necessary. + fn extract_message(&mut self, data: FlightData) -> Result> { + use arrow_ipc::MessageHeader; + let message = arrow_ipc::root_as_message(&data.data_header[..]).map_err(|e| { + FlightError::DecodeError(format!("Error decoding root message: {e}")) + })?; + + match message.header_type() { + MessageHeader::NONE => Ok(Some(DecodedFlightData::new_none(data))), + MessageHeader::Schema => { + let schema = Schema::try_from(&data).map_err(|e| { + FlightError::DecodeError(format!("Error decoding schema: {e}")) + })?; + + let schema = Arc::new(schema); + let dictionaries_by_field = HashMap::new(); + + self.state = Some(FlightStreamState { + schema: Arc::clone(&schema), + dictionaries_by_field, + }); + Ok(Some(DecodedFlightData::new_schema(data, schema))) + } + MessageHeader::DictionaryBatch => { + let state = if let Some(state) = self.state.as_mut() { + state + } else { + return Err(FlightError::protocol( + "Received DictionaryBatch prior to Schema", + )); + }; + + let buffer: arrow_buffer::Buffer = data.data_body.into(); + let dictionary_batch = + message.header_as_dictionary_batch().ok_or_else(|| { + FlightError::protocol( + "Could not get dictionary batch from DictionaryBatch message", + ) + })?; + + arrow_ipc::reader::read_dictionary( + &buffer, + dictionary_batch, + &state.schema, + &mut state.dictionaries_by_field, + &message.version(), + ) + .map_err(|e| { + FlightError::DecodeError(format!( + "Error decoding ipc dictionary: {e}" + )) + })?; + + // Updated internal state, but no decoded message + Ok(None) + } + MessageHeader::RecordBatch => { + let state = if let Some(state) = self.state.as_ref() { + state + } else { + return Err(FlightError::protocol( + "Received RecordBatch prior to Schema", + )); + }; + + let batch = flight_data_to_arrow_batch( + &data, + Arc::clone(&state.schema), + &state.dictionaries_by_field, + ) + .map_err(|e| { + FlightError::DecodeError(format!( + "Error decoding ipc RecordBatch: {e}" + )) + })?; + + Ok(Some(DecodedFlightData::new_record_batch(data, batch))) + } + other => { + let name = other.variant_name().unwrap_or("UNKNOWN"); + Err(FlightError::protocol(format!("Unexpected message: {name}"))) + } + } + } +} + +impl futures::Stream for FlightDataDecoder { + type Item = Result; + /// Returns the result of decoding the next [`FlightData`] message + /// from the server, or `None` if there are no further results + /// available. + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + if self.done { + return Poll::Ready(None); + } + loop { + let res = ready!(self.response.poll_next_unpin(cx)); + + return Poll::Ready(match res { + None => { + self.done = true; + None // inner is exhausted + } + Some(data) => Some(match data { + Err(e) => Err(e), + Ok(data) => match self.extract_message(data) { + Ok(Some(extracted)) => Ok(extracted), + Ok(None) => continue, // Need next input message + Err(e) => Err(e), + }, + }), + }); + } + } +} + +/// tracks the state needed to reconstruct [`RecordBatch`]es from a +/// streaming flight response. +#[derive(Debug)] +struct FlightStreamState { + schema: Arc, + dictionaries_by_field: HashMap, +} + +/// FlightData and the decoded payload (Schema, RecordBatch), if any +#[derive(Debug)] +pub struct DecodedFlightData { + pub inner: FlightData, + pub payload: DecodedPayload, +} + +impl DecodedFlightData { + pub fn new_none(inner: FlightData) -> Self { + Self { + inner, + payload: DecodedPayload::None, + } + } + + pub fn new_schema(inner: FlightData, schema: Arc) -> Self { + Self { + inner, + payload: DecodedPayload::Schema(schema), + } + } + + pub fn new_record_batch(inner: FlightData, batch: RecordBatch) -> Self { + Self { + inner, + payload: DecodedPayload::RecordBatch(batch), + } + } + + /// return the metadata field of the inner flight data + pub fn app_metadata(&self) -> Bytes { + self.inner.app_metadata.clone() + } +} + +/// The result of decoding [`FlightData`] +#[derive(Debug)] +pub enum DecodedPayload { + /// None (no data was sent in the corresponding FlightData) + None, + + /// A decoded Schema message + Schema(Arc), + + /// A decoded Record batch. + RecordBatch(RecordBatch), +} diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs new file mode 100644 index 000000000000..a946eefcf6d8 --- /dev/null +++ b/arrow-flight/src/encode.rs @@ -0,0 +1,503 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; + +use crate::{error::FlightError, error::Result, FlightData, SchemaAsIpc}; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use bytes::Bytes; +use futures::{ready, stream::BoxStream, Stream, StreamExt}; + +/// Creates a [`Stream`](futures::Stream) of [`FlightData`]s from a +/// `Stream` of [`Result`]<[`RecordBatch`], [`FlightError`]>. +/// +/// This can be used to implement [`FlightService::do_get`] in an +/// Arrow Flight implementation; +/// +/// # Caveats +/// 1. [`DictionaryArray`](arrow_array::array::DictionaryArray)s +/// are converted to their underlying types prior to transport, due to +/// . +/// +/// # Example +/// ```no_run +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +/// # async fn f() { +/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); +/// # let record_batch = RecordBatch::try_from_iter(vec![ +/// # ("a", Arc::new(c1) as ArrayRef) +/// # ]) +/// # .expect("cannot create record batch"); +/// use arrow_flight::encode::FlightDataEncoderBuilder; +/// +/// // Get an input stream of Result +/// let input_stream = futures::stream::iter(vec![Ok(record_batch)]); +/// +/// // Build a stream of `Result` (e.g. to return for do_get) +/// let flight_data_stream = FlightDataEncoderBuilder::new() +/// .build(input_stream); +/// +/// // Create a tonic `Response` that can be returned from a Flight server +/// let response = tonic::Response::new(flight_data_stream); +/// # } +/// ``` +/// +/// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get +#[derive(Debug)] +pub struct FlightDataEncoderBuilder { + /// The maximum message size (see details on [`Self::with_max_message_size`]). + max_batch_size: usize, + /// Ipc writer options + options: IpcWriteOptions, + /// Metadata to add to the schema message + app_metadata: Bytes, +} + +/// Default target size for record batches to send. +/// +/// Note this value would normally be 4MB, but the size calculation is +/// somehwhat inexact, so we set it to 2MB. +pub const GRPC_TARGET_MAX_BATCH_SIZE: usize = 2097152; + +impl Default for FlightDataEncoderBuilder { + fn default() -> Self { + Self { + max_batch_size: GRPC_TARGET_MAX_BATCH_SIZE, + options: IpcWriteOptions::default(), + app_metadata: Bytes::new(), + } + } +} + +impl FlightDataEncoderBuilder { + pub fn new() -> Self { + Self::default() + } + + /// Set the (approximate) maximum encoded [`RecordBatch`] size to + /// limit the gRPC message size. Defaults fo 2MB. + /// + /// The encoder splits up [`RecordBatch`]s (preserving order) to + /// limit individual messages to approximately this size. The size + /// is approximate because there additional encoding overhead on + /// top of the underlying data itself. + /// + pub fn with_max_message_size(mut self, max_batch_size: usize) -> Self { + self.max_batch_size = max_batch_size; + self + } + + /// Specfy application specific metadata included in the + /// [`FlightData::app_metadata`] field of the the first Schema + /// message + pub fn with_metadata(mut self, app_metadata: Bytes) -> Self { + self.app_metadata = app_metadata; + self + } + + /// Set the [`IpcWriteOptions`] used to encode the [`RecordBatch`]es for transport. + pub fn with_options(mut self, options: IpcWriteOptions) -> Self { + self.options = options; + self + } + + /// Return a [`Stream`](futures::Stream) of [`FlightData`], + /// consuming self. More details on [`FlightDataEncoderBuilder`] + pub fn build(self, input: S) -> FlightDataEncoder + where + S: Stream> + Send + 'static, + { + let Self { + max_batch_size, + options, + app_metadata, + } = self; + + FlightDataEncoder::new(input.boxed(), max_batch_size, options, app_metadata) + } +} + +/// Stream that encodes a stream of record batches to flight data. +/// +/// See [`FlightDataEncoderBuilder`] for details and example. +pub struct FlightDataEncoder { + /// Input stream + inner: BoxStream<'static, Result>, + /// schema, set after the first batch + schema: Option, + /// Max sixe of batches to encode + max_batch_size: usize, + /// do the encoding / tracking of dictionaries + encoder: FlightIpcEncoder, + /// optional metadata to add to schema FlightData + app_metadata: Option, + /// data queued up to send but not yet sent + queue: VecDeque, + /// Is this strema done (inner is empty or errored) + done: bool, +} + +impl FlightDataEncoder { + fn new( + inner: BoxStream<'static, Result>, + max_batch_size: usize, + options: IpcWriteOptions, + app_metadata: Bytes, + ) -> Self { + Self { + inner, + schema: None, + max_batch_size, + encoder: FlightIpcEncoder::new(options), + app_metadata: Some(app_metadata), + queue: VecDeque::new(), + done: false, + } + } + + /// Place the `FlightData` in the queue to send + fn queue_message(&mut self, data: FlightData) { + self.queue.push_back(data); + } + + /// Place the `FlightData` in the queue to send + fn queue_messages(&mut self, datas: Vec) { + for data in datas { + self.queue_message(data) + } + } + + /// Encodes batch into one or more `FlightData` messages in self.queue + fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> { + let schema = match self.schema.take() { + Some(schema) => schema, + None => { + let batch_schema = batch.schema(); + // The first message is the schema message, and all + // batches have the same schema + let schema = Arc::new(prepare_schema_for_flight(&batch_schema)); + let mut schema_flight_data = self.encoder.encode_schema(&schema); + + // attach any metadata requested + if let Some(app_metadata) = self.app_metadata.take() { + schema_flight_data.app_metadata = app_metadata; + } + self.queue_message(schema_flight_data); + schema + } + }; + + // remember schema + self.schema = Some(schema.clone()); + + // encode the batch + let batch = prepare_batch_for_flight(&batch, schema)?; + + for batch in split_batch_for_grpc_response(batch, self.max_batch_size) { + let (flight_dictionaries, flight_batch) = + self.encoder.encode_batch(&batch)?; + + self.queue_messages(flight_dictionaries); + self.queue_message(flight_batch); + } + + Ok(()) + } +} + +impl Stream for FlightDataEncoder { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + loop { + if self.done && self.queue.is_empty() { + return Poll::Ready(None); + } + + // Any messages queued to send? + if let Some(data) = self.queue.pop_front() { + return Poll::Ready(Some(Ok(data))); + } + + // Get next batch + let batch = ready!(self.inner.poll_next_unpin(cx)); + + match batch { + None => { + // inner is done + self.done = true; + } + Some(Err(e)) => { + // error from inner + self.done = true; + self.queue.clear(); + return Poll::Ready(Some(Err(e))); + } + Some(Ok(batch)) => { + // had data, encode into the queue + if let Err(e) = self.encode_batch(batch) { + self.done = true; + self.queue.clear(); + return Poll::Ready(Some(Err(e))); + } + } + } + } + } +} + +/// Prepare an arrow Schema for transport over the Arrow Flight protocol +/// +/// Convert dictionary types to underlying types +/// +/// See hydrate_dictionary for more information +pub fn prepare_schema_for_flight(schema: &Schema) -> Schema { + let fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) => Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + _ => field.clone(), + }) + .collect(); + + Schema::new(fields) +} + +/// Split [`RecordBatch`] so it hopefully fits into a gRPC response. +/// +/// Data is zero-copy sliced into batches. +pub fn split_batch_for_grpc_response( + batch: RecordBatch, + max_batch_size: usize, +) -> Vec { + let size = batch + .columns() + .iter() + .map(|col| col.get_array_memory_size()) + .sum::(); + + let n_batches = + (size / max_batch_size + usize::from(size % max_batch_size != 0)).max(1); + let rows_per_batch = (batch.num_rows() / n_batches).max(1); + let mut out = Vec::with_capacity(n_batches + 1); + + let mut offset = 0; + while offset < batch.num_rows() { + let length = (rows_per_batch).min(batch.num_rows() - offset); + out.push(batch.slice(offset, length)); + + offset += length; + } + + out +} + +/// The data needed to encode a stream of flight data, holding on to +/// shared Dictionaries. +/// +/// TODO: at allow dictionaries to be flushed / avoid building them +/// +/// TODO limit on the number of dictionaries??? +struct FlightIpcEncoder { + options: IpcWriteOptions, + data_gen: IpcDataGenerator, + dictionary_tracker: DictionaryTracker, +} + +impl FlightIpcEncoder { + fn new(options: IpcWriteOptions) -> Self { + let error_on_replacement = true; + Self { + options, + data_gen: IpcDataGenerator::default(), + dictionary_tracker: DictionaryTracker::new(error_on_replacement), + } + } + + /// Encode a schema as a FlightData + fn encode_schema(&self, schema: &Schema) -> FlightData { + SchemaAsIpc::new(schema, &self.options).into() + } + + /// Convert a `RecordBatch` to a Vec of `FlightData` representing + /// dictionaries and a `FlightData` representing the batch + fn encode_batch( + &mut self, + batch: &RecordBatch, + ) -> Result<(Vec, FlightData)> { + let (encoded_dictionaries, encoded_batch) = self + .data_gen + .encoded_batch(batch, &mut self.dictionary_tracker, &self.options) + .map_err(FlightError::Arrow)?; + + let flight_dictionaries = + encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + Ok((flight_dictionaries, flight_batch)) + } +} + +/// Prepares a RecordBatch for transport over the Arrow Flight protocol +/// +/// This means: +/// +/// 1. Hydrates any dictionaries to its underlying type. See +/// hydrate_dictionary for more information. +/// +pub fn prepare_batch_for_flight( + batch: &RecordBatch, + schema: SchemaRef, +) -> Result { + let columns = batch + .columns() + .iter() + .map(hydrate_dictionary) + .collect::>>()?; + + RecordBatch::try_new(schema, columns).map_err(FlightError::Arrow) +} + +/// Hydrates a dictionary to its underlying type +/// +/// An IPC response, streaming or otherwise, defines its schema up front +/// which defines the mapping from dictionary IDs. It then sends these +/// dictionaries over the wire. +/// +/// This requires identifying the different dictionaries in use, assigning +/// them IDs, and sending new dictionaries, delta or otherwise, when needed +/// +/// See also: +/// * +/// +/// For now we just hydrate the dictionaries to their underlying type +fn hydrate_dictionary(array: &ArrayRef) -> Result { + if let DataType::Dictionary(_, value) = array.data_type() { + arrow_cast::cast(array, value).map_err(FlightError::Arrow) + } else { + Ok(Arc::clone(array)) + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{UInt32Array, UInt8Array}, + compute::concat_batches, + }; + + use super::*; + + #[test] + /// ensure only the batch's used data (not the allocated data) is sent + /// + fn test_encode_flight_data() { + let options = arrow::ipc::writer::IpcWriteOptions::default(); + let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); + + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)]) + .expect("cannot create record batch"); + let schema = batch.schema(); + + let (_, baseline_flight_batch) = make_flight_data(&batch, &options); + + let big_batch = batch.slice(0, batch.num_rows() - 1); + let optimized_big_batch = + prepare_batch_for_flight(&big_batch, Arc::clone(&schema)) + .expect("failed to optimize"); + let (_, optimized_big_flight_batch) = + make_flight_data(&optimized_big_batch, &options); + + assert_eq!( + baseline_flight_batch.data_body.len(), + optimized_big_flight_batch.data_body.len() + ); + + let small_batch = batch.slice(0, 1); + let optimized_small_batch = + prepare_batch_for_flight(&small_batch, Arc::clone(&schema)) + .expect("failed to optimize"); + let (_, optimized_small_flight_batch) = + make_flight_data(&optimized_small_batch, &options); + + assert!( + baseline_flight_batch.data_body.len() + > optimized_small_flight_batch.data_body.len() + ); + } + + pub fn make_flight_data( + batch: &RecordBatch, + options: &IpcWriteOptions, + ) -> (Vec, FlightData) { + let data_gen = IpcDataGenerator::default(); + let mut dictionary_tracker = DictionaryTracker::new(false); + + let (encoded_dictionaries, encoded_batch) = data_gen + .encoded_batch(batch, &mut dictionary_tracker, options) + .expect("DictionaryTracker configured above to not error on replacement"); + + let flight_dictionaries = + encoded_dictionaries.into_iter().map(Into::into).collect(); + let flight_batch = encoded_batch.into(); + + (flight_dictionaries, flight_batch) + } + + #[test] + fn test_split_batch_for_grpc_response() { + let max_batch_size = 1024; + + // no split + let c = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) + .expect("cannot create record batch"); + let split = split_batch_for_grpc_response(batch.clone(), max_batch_size); + assert_eq!(split.len(), 1); + assert_eq!(batch, split[0]); + + // split once + let n_rows = max_batch_size + 1; + assert!(n_rows % 2 == 1, "should be an odd number"); + let c = + UInt8Array::from((0..n_rows).map(|i| (i % 256) as u8).collect::>()); + let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c) as ArrayRef)]) + .expect("cannot create record batch"); + let split = split_batch_for_grpc_response(batch.clone(), max_batch_size); + assert_eq!(split.len(), 3); + assert_eq!( + split.iter().map(|batch| batch.num_rows()).sum::(), + n_rows + ); + assert_eq!(concat_batches(&batch.schema(), &split).unwrap(), batch); + } + + // test sending record batches + // test sending record batches with multiple different dictionaries +} diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs index fbb9efa44c24..7395c3362a83 100644 --- a/arrow-flight/src/error.rs +++ b/arrow-flight/src/error.rs @@ -15,9 +15,13 @@ // specific language governing permissions and limitations // under the License. +use arrow_schema::ArrowError; + /// Errors for the Apache Arrow Flight crate #[derive(Debug)] pub enum FlightError { + /// Underlying arrow error + Arrow(ArrowError), /// Returned when functionality is not yet available. NotYetImplemented(String), /// Error from the underlying tonic library @@ -56,4 +60,19 @@ impl From for FlightError { } } +// default conversion from FlightError to tonic treats everything +// other than `Status` as an internal error +impl From for tonic::Status { + fn from(value: FlightError) -> Self { + match value { + FlightError::Arrow(e) => tonic::Status::internal(e.to_string()), + FlightError::NotYetImplemented(e) => tonic::Status::internal(e), + FlightError::Tonic(status) => status, + FlightError::ProtocolError(e) => tonic::Status::internal(e), + FlightError::DecodeError(e) => tonic::Status::internal(e), + FlightError::ExternalError(e) => tonic::Status::internal(e.to_string()), + } + } +} + pub type Result = std::result::Result; diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index f30cb54844da..c2da58eb5bb7 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -71,10 +71,18 @@ pub mod flight_service_server { pub use gen::flight_service_server::FlightServiceServer; } -/// Mid Level [`FlightClient`] for +/// Mid Level [`FlightClient`] pub mod client; pub use client::FlightClient; +/// Decoder to create [`RecordBatch`](arrow_array::RecordBatch) streams from [`FlightData`] streams. +/// See [`FlightRecordBatchStream`](decode::FlightRecordBatchStream). +pub mod decode; + +/// Encoder to create [`FlightData`] streams from [`RecordBatch`](arrow_array::RecordBatch) streams. +/// See [`FlightDataEncoderBuilder`](encode::FlightDataEncoderBuilder). +pub mod encode; + /// Common error types pub mod error; diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 5bc1062f046d..c471294d7dc4 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -20,20 +20,21 @@ mod common { pub mod server; } +use arrow_array::{RecordBatch, UInt64Array}; use arrow_flight::{ error::FlightError, FlightClient, FlightDescriptor, FlightInfo, HandshakeRequest, - HandshakeResponse, + HandshakeResponse, Ticket, }; use bytes::Bytes; use common::server::TestFlightServer; -use futures::Future; +use futures::{Future, TryStreamExt}; use tokio::{net::TcpListener, task::JoinHandle}; use tonic::{ transport::{Channel, Uri}, Status, }; -use std::{net::SocketAddr, time::Duration}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; const DEFAULT_TIMEOUT_SECONDS: u64 = 30; @@ -173,7 +174,90 @@ async fn test_get_flight_info_metadata() { // TODO more negative tests (like if there are endpoints defined, etc) -// TODO test for do_get +#[tokio::test] +async fn test_do_get() { + do_test(|test_server, mut client| async move { + let ticket = Ticket { + ticket: Bytes::from("my awesome flight ticket"), + }; + + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + let response = vec![Ok(batch.clone())]; + test_server.set_do_get_response(response); + let response_stream = client + .do_get(ticket.clone()) + .await + .expect("error making request"); + + let expected_response = vec![batch]; + let response: Vec<_> = response_stream + .try_collect() + .await + .expect("Error streaming data"); + + assert_eq!(response, expected_response); + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + }) + .await; +} + +#[tokio::test] +async fn test_do_get_error() { + do_test(|test_server, mut client| async move { + client.add_header("foo", "bar").unwrap(); + let ticket = Ticket { + ticket: Bytes::from("my awesome flight ticket"), + }; + + let response = client.do_get(ticket.clone()).await.unwrap_err(); + + let e = Status::internal("No do_get response configured"); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + ensure_metadata(&client, &test_server); + }) + .await; +} + +#[tokio::test] +async fn test_do_get_error_in_record_batch_stream() { + do_test(|test_server, mut client| async move { + let ticket = Ticket { + ticket: Bytes::from("my awesome flight ticket"), + }; + + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + let e = Status::data_loss("she's dead jim"); + + let expected_response = vec![Ok(batch), Err(FlightError::Tonic(e.clone()))]; + + test_server.set_do_get_response(expected_response); + + let response_stream = client + .do_get(ticket.clone()) + .await + .expect("error making request"); + + let response: Result, FlightError> = response_stream.try_collect().await; + + let response = response.unwrap_err(); + expect_status(response, e); + // server still got the request + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + }) + .await; +} /// Runs the future returned by the function, passing it a test server and client async fn do_test(f: F) diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs index f1cb140b68c7..45f81b189e8d 100644 --- a/arrow-flight/tests/common/server.rs +++ b/arrow-flight/tests/common/server.rs @@ -17,10 +17,13 @@ use std::sync::{Arc, Mutex}; -use futures::stream::BoxStream; +use arrow_array::RecordBatch; +use futures::{stream::BoxStream, TryStreamExt}; use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; use arrow_flight::{ + encode::FlightDataEncoderBuilder, + error::FlightError, flight_service_server::{FlightService, FlightServiceServer}, Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, @@ -80,6 +83,21 @@ impl TestFlightServer { .take() } + /// Specify the response returned from the next call to `do_get` + pub fn set_do_get_response(&self, response: Vec>) { + let mut state = self.state.lock().expect("mutex not poisoned"); + state.do_get_response.replace(response); + } + + /// Take and return last do_get request send to the server, + pub fn take_do_get_request(&self) -> Option { + self.state + .lock() + .expect("mutex not poisoned") + .do_get_request + .take() + } + /// Returns the last metadata from a request received by the server pub fn take_last_request_metadata(&self) -> Option { self.state @@ -97,7 +115,7 @@ impl TestFlightServer { } } -/// mutable state for the TestFlightSwrver +/// mutable state for the TestFlightServer, captures requests and provides responses #[derive(Debug, Default)] struct State { /// The last handshake request that was received @@ -108,6 +126,10 @@ struct State { pub get_flight_info_request: Option, /// the next response to return from `get_flight_info` pub get_flight_info_response: Option>, + /// The last do_get request received + pub do_get_request: Option, + /// The next response returned from `do_get` + pub do_get_response: Option>>, /// The last request headers received pub last_request_metadata: Option, } @@ -177,9 +199,25 @@ impl FlightService for TestFlightServer { async fn do_get( &self, - _request: Request, + request: Request, ) -> Result, Status> { - Err(Status::unimplemented("Implement do_get")) + self.save_metadata(&request); + let mut state = self.state.lock().expect("mutex not poisoned"); + + state.do_get_request = Some(request.into_inner()); + + let batches: Vec<_> = state + .do_get_response + .take() + .ok_or_else(|| Status::internal("No do_get response configured"))?; + + let batch_stream = futures::stream::iter(batches); + + let stream = FlightDataEncoderBuilder::new() + .build(batch_stream) + .map_err(|e| e.into()); + + Ok(Response::new(Box::pin(stream) as _)) } async fn do_put( diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs new file mode 100644 index 000000000000..29afa9bead3e --- /dev/null +++ b/arrow-flight/tests/encode_decode.rs @@ -0,0 +1,283 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for round trip encoding / decoding + +use std::sync::Arc; + +use arrow::{compute::concat_batches, datatypes::Int32Type}; +use arrow_array::{ArrayRef, DictionaryArray, Float64Array, RecordBatch, UInt8Array}; +use arrow_flight::{ + decode::{DecodedPayload, FlightRecordBatchStream}, + encode::{ + prepare_batch_for_flight, prepare_schema_for_flight, FlightDataEncoderBuilder, + }, + error::FlightError, +}; +use bytes::Bytes; +use futures::{StreamExt, TryStreamExt}; + +#[tokio::test] +async fn test_empty() { + roundtrip(vec![]).await; +} + +#[tokio::test] +async fn test_empty_batch() { + let batch = make_primative_batch(5); + let empty = RecordBatch::new_empty(batch.schema()); + roundtrip(vec![empty]).await; +} + +#[tokio::test] +async fn test_error() { + let input_batch_stream = + futures::stream::iter(vec![Err(FlightError::NotYetImplemented("foo".into()))]); + + let encoder = FlightDataEncoderBuilder::default(); + let encode_stream = encoder.build(input_batch_stream); + + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, _> = decode_stream.try_collect().await; + + let result = result.unwrap_err(); + assert_eq!(result.to_string(), r#"NotYetImplemented("foo")"#); +} + +#[tokio::test] +async fn test_primative_one() { + roundtrip(vec![make_primative_batch(5)]).await; +} + +#[tokio::test] +async fn test_primative_many() { + roundtrip(vec![ + make_primative_batch(1), + make_primative_batch(7), + make_primative_batch(32), + ]) + .await; +} + +#[tokio::test] +async fn test_primative_empty() { + let batch = make_primative_batch(5); + let empty = RecordBatch::new_empty(batch.schema()); + + roundtrip(vec![batch, empty]).await; +} + +#[tokio::test] +async fn test_dictionary_one() { + roundtrip_dictionary(vec![make_dictionary_batch(5)]).await; +} + +#[tokio::test] +async fn test_dictionary_many() { + roundtrip_dictionary(vec![ + make_dictionary_batch(5), + make_dictionary_batch(9), + make_dictionary_batch(5), + make_dictionary_batch(5), + ]) + .await; +} + +#[tokio::test] +async fn test_app_metadata() { + let input_batch_stream = futures::stream::iter(vec![Ok(make_primative_batch(78))]); + + let app_metadata = Bytes::from("My Metadata"); + let encoder = FlightDataEncoderBuilder::default().with_metadata(app_metadata.clone()); + + let encode_stream = encoder.build(input_batch_stream); + + // use lower level stream to get access to app metadata + let decode_stream = + FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); + + let mut messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails"); + + println!("{messages:#?}"); + + // expect that the app metadata made it through on the schema message + assert_eq!(messages.len(), 2); + let message2 = messages.pop().unwrap(); + let message1 = messages.pop().unwrap(); + + assert_eq!(message1.app_metadata(), app_metadata); + assert!(matches!(message1.payload, DecodedPayload::Schema(_))); + + // but not on the data + assert_eq!(message2.app_metadata(), Bytes::new()); + assert!(matches!(message2.payload, DecodedPayload::RecordBatch(_))); +} + +#[tokio::test] +async fn test_max_message_size() { + let input_batch_stream = futures::stream::iter(vec![Ok(make_primative_batch(5))]); + + // 5 input rows, with a very small limit should result in 5 batch messages + let encoder = FlightDataEncoderBuilder::default().with_max_message_size(1); + + let encode_stream = encoder.build(input_batch_stream); + + // use lower level stream to get access to app metadata + let decode_stream = + FlightRecordBatchStream::new_from_flight_data(encode_stream).into_inner(); + + let messages: Vec<_> = decode_stream.try_collect().await.expect("encode fails"); + + println!("{messages:#?}"); + + assert_eq!(messages.len(), 6); + assert!(matches!(messages[0].payload, DecodedPayload::Schema(_))); + for message in messages.iter().skip(1) { + assert!(matches!(message.payload, DecodedPayload::RecordBatch(_))); + } +} + +#[tokio::test] +async fn test_max_message_size_fuzz() { + // send through batches of varying sizes with various max + // batch sizes and ensure the data gets through ok + let input = vec![ + make_primative_batch(123), + make_primative_batch(17), + make_primative_batch(201), + make_primative_batch(2), + make_primative_batch(1), + make_primative_batch(11), + make_primative_batch(127), + ]; + + for max_message_size in [10, 1024, 2048, 6400, 3211212] { + let encoder = + FlightDataEncoderBuilder::default().with_max_message_size(max_message_size); + + let input_batch_stream = futures::stream::iter(input.clone()).map(Ok); + + let encode_stream = encoder.build(input_batch_stream); + + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let output: Vec<_> = decode_stream.try_collect().await.expect("encode / decode"); + + let input_batch = concat_batches(&input[0].schema(), &input).unwrap(); + let output_batch = concat_batches(&output[0].schema(), &output).unwrap(); + assert_eq!(input_batch, output_batch); + } +} + +/// Make a primtive batch for testing +/// +/// Example: +/// i: 0, 1, None, 3, 4 +/// f: 5.0, 4.0, None, 2.0, 1.0 +fn make_primative_batch(num_rows: usize) -> RecordBatch { + let i: UInt8Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some(i.try_into().unwrap()) + } + }) + .collect(); + + let f: Float64Array = (0..num_rows) + .map(|i| { + if i == num_rows / 2 { + None + } else { + Some((num_rows - i) as f64) + } + }) + .collect(); + + RecordBatch::try_from_iter(vec![("i", Arc::new(i) as ArrayRef), ("f", Arc::new(f))]) + .unwrap() +} + +fn make_dictionary_batch(num_rows: usize) -> RecordBatch { + let values: Vec<_> = (0..num_rows) + .map(|i| { + if i == i / 2 { + None + } else { + // repeat some values for low cardinality + let v = i / 3; + Some(format!("value{v}")) + } + }) + .collect(); + + let a: DictionaryArray = values + .iter() + .map(|s| s.as_ref().map(|s| s.as_str())) + .collect(); + + RecordBatch::try_from_iter(vec![("a", Arc::new(a) as ArrayRef)]).unwrap() +} + +/// Encodes input as a FlightData stream, and then decodes it using +/// FlightRecordBatchStream and valides the decoded record batches +/// match the input. +async fn roundtrip(input: Vec) { + let expected_output = input.clone(); + roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output) + .await +} + +/// Encodes input as a FlightData stream, and then decodes it using +/// FlightRecordBatchStream and valides the decoded record batches +/// match the expected input. +/// +/// When is resolved, +/// it should be possible to use `roundtrip` +async fn roundtrip_dictionary(input: Vec) { + let schema = Arc::new(prepare_schema_for_flight(&input[0].schema())); + let expected_output: Vec<_> = input + .iter() + .map(|batch| prepare_batch_for_flight(batch, schema.clone()).unwrap()) + .collect(); + roundtrip_with_encoder(FlightDataEncoderBuilder::default(), input, expected_output) + .await +} + +async fn roundtrip_with_encoder( + encoder: FlightDataEncoderBuilder, + input_batches: Vec, + expected_batches: Vec, +) { + println!("Round tripping with encoder:\n{encoder:#?}"); + + let input_batch_stream = futures::stream::iter(input_batches.clone()).map(Ok); + + let encode_stream = encoder.build(input_batch_stream); + + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let output_batches: Vec<_> = + decode_stream.try_collect().await.expect("encode / decode"); + + // remove any empty batches from input as they are not transmitted + let expected_batches: Vec<_> = expected_batches + .into_iter() + .filter(|b| b.num_rows() > 0) + .collect(); + + assert_eq!(expected_batches, output_batches); +} diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index 106b4e4c9850..82cf2c90b852 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -37,7 +37,7 @@ use arrow_schema::*; use crate::compression::CompressionCodec; use crate::CONTINUATION_MARKER; -/// IPC write options used to control the behaviour of the writer +/// IPC write options used to control the behaviour of the [`IpcDataGenerator`] #[derive(Debug, Clone)] pub struct IpcWriteOptions { /// Write padding after memory buffers to this multiple of bytes. @@ -514,6 +514,9 @@ pub struct DictionaryTracker { } impl DictionaryTracker { + /// Create a new [`DictionaryTracker`]. If `error_on_replacement` + /// is true, an error will be generated if an update to an + /// existing dictionary is attempted. pub fn new(error_on_replacement: bool) -> Self { Self { written: HashMap::new(), From ad3c1fe6b52de131082568b7001c6ba136337c5a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 25 Dec 2022 16:35:09 -0600 Subject: [PATCH 02/16] fix comment --- arrow-flight/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index c892a8bcdf9e..0e8072f5631c 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -29,7 +29,7 @@ license = "Apache-2.0" [dependencies] arrow-array = { version = "29.0.0", path = "../arrow-array" } arrow-buffer = { version = "29.0.0", path = "../arrow-buffer" } -# TODO is this needed?? +# Cast is needed to work around https://github.com/apache/arrow-rs/issues/3389 arrow-cast = { version = "29.0.0", path = "../arrow-cast" } arrow-ipc = { version = "29.0.0", path = "../arrow-ipc" } arrow-schema = { version = "29.0.0", path = "../arrow-schema" } From feeedf0976907ac913708a53240ce692a3b889d9 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 26 Dec 2022 09:09:44 -0500 Subject: [PATCH 03/16] Update arrow-flight/src/encode.rs Co-authored-by: Liang-Chi Hsieh --- arrow-flight/src/encode.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index a946eefcf6d8..70147e606f99 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -119,7 +119,7 @@ impl FlightDataEncoderBuilder { } /// Return a [`Stream`](futures::Stream) of [`FlightData`], - /// consuming self. More details on [`FlightDataEncoderBuilder`] + /// consuming self. More details on [`FlightDataEncoder`] pub fn build(self, input: S) -> FlightDataEncoder where S: Stream> + Send + 'static, From 4671454bc2a46bd1d05b7737ef615174ff78ebd6 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 26 Dec 2022 08:31:34 -0600 Subject: [PATCH 04/16] Add test encoding error --- arrow-flight/tests/encode_decode.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 29afa9bead3e..ba3f7d3460fa 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -182,6 +182,27 @@ async fn test_max_message_size_fuzz() { } } + +#[tokio::test] +async fn test_mismatched_record_batch_schema() { + // send 2 batches with different schemas + let input_batch_stream = futures::stream::iter(vec![ + Ok(make_primative_batch(5)), + Ok(make_dictionary_batch(3)), + ]); + + let encoder = FlightDataEncoderBuilder::default(); + let encode_stream = encoder.build(input_batch_stream); + + let result: Result, FlightError> = encode_stream.try_collect().await; + let err = result.unwrap_err(); + assert_eq!( + err.to_string(), + "Arrow(InvalidArgumentError(\"number of columns(1) must match number of fields(2) in schema\"))" + ); +} + + /// Make a primtive batch for testing /// /// Example: From 69f32a92c974c5557a0bb4bc5f039806b4dbb65b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 26 Dec 2022 08:45:38 -0600 Subject: [PATCH 05/16] Add test for chained streams --- arrow-flight/src/client.rs | 7 ++-- arrow-flight/tests/encode_decode.rs | 63 ++++++++++++++++++++++++++++- 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index 31a912535652..753c40f2a5c1 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -16,8 +16,7 @@ // under the License. use crate::{ - decode::{FlightRecordBatchStream}, - flight_service_client::FlightServiceClient, + decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, FlightDescriptor, FlightInfo, HandshakeRequest, Ticket, }; use bytes::Bytes; @@ -203,7 +202,9 @@ impl FlightClient { // convert to FlightError .map_err(|e| e.into()); - Ok(FlightRecordBatchStream::new_from_flight_data(response_stream)) + Ok(FlightRecordBatchStream::new_from_flight_data( + response_stream, + )) } /// Make a `GetFlightInfo` call to the server with the provided diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index ba3f7d3460fa..768f5eff7258 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::{compute::concat_batches, datatypes::Int32Type}; use arrow_array::{ArrayRef, DictionaryArray, Float64Array, RecordBatch, UInt8Array}; use arrow_flight::{ - decode::{DecodedPayload, FlightRecordBatchStream}, + decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream}, encode::{ prepare_batch_for_flight, prepare_schema_for_flight, FlightDataEncoderBuilder, }, @@ -182,7 +182,6 @@ async fn test_max_message_size_fuzz() { } } - #[tokio::test] async fn test_mismatched_record_batch_schema() { // send 2 batches with different schemas @@ -202,6 +201,66 @@ async fn test_mismatched_record_batch_schema() { ); } +#[tokio::test] +async fn test_chained_streams_batch_decoder() { + let batch1 = make_primative_batch(5); + let batch2 = make_dictionary_batch(3); + + // Model sending two flight streams back to back, with different schemas + let encode_stream1 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch1.clone())])); + let encode_stream2 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch2.clone())])); + + // append the two streams (so they will have two different schema messages) + let encode_stream = encode_stream1.chain(encode_stream2); + + // FlightRecordBatchStream errors if the schema changes + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, FlightError> = decode_stream.try_collect().await; + + let err = result.unwrap_err(); + assert_eq!( + err.to_string(), + "ProtocolError(\"Unexpectedly saw multiple Schema messages in FlightData stream\")" + ); +} + +#[tokio::test] +async fn test_chained_streams_data_decoder() { + let batch1 = make_primative_batch(5); + let batch2 = make_dictionary_batch(3); + + // Model sending two flight streams back to back, with different schemas + let encode_stream1 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch1.clone())])); + let encode_stream2 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch2.clone())])); + + // append the two streams (so they will have two different schema messages) + let encode_stream = encode_stream1.chain(encode_stream2); + + // lower level decode stream can handle multiple schema messages + let decode_stream = FlightDataDecoder::new(encode_stream); + + let decoded_data: Vec<_> = + decode_stream.try_collect().await.expect("encode / decode"); + + println!("decoded data: {decoded_data:#?}"); + + // expect two schema messages with the data + assert_eq!(decoded_data.len(), 4); + assert!(matches!(decoded_data[0].payload, DecodedPayload::Schema(_))); + assert!(matches!( + decoded_data[1].payload, + DecodedPayload::RecordBatch(_) + )); + assert!(matches!(decoded_data[2].payload, DecodedPayload::Schema(_))); + assert!(matches!( + decoded_data[3].payload, + DecodedPayload::RecordBatch(_) + )); +} /// Make a primtive batch for testing /// From 28661917786b41769452556fe155cb912d3f87bf Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 26 Dec 2022 08:54:38 -0600 Subject: [PATCH 06/16] Add mismatched schema and data test --- arrow-flight/tests/encode_decode.rs | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 768f5eff7258..fc16f1ec9b38 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -262,6 +262,37 @@ async fn test_chained_streams_data_decoder() { )); } +#[tokio::test] +#[should_panic(expected = "assertion failed: idx < self.len()")] +async fn test_mismatched_schema_message() { + let batch1 = make_primative_batch(5); + let batch2 = make_dictionary_batch(3); + + // Model sending schema that is mismatched with the data + + let encode_stream1 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch1.clone())])) + // take only schema message from first stream + .take(1); + let encode_stream2 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch2.clone())])) + // take only data message from second + .skip(1); + + // append the two streams + let encode_stream = encode_stream1.chain(encode_stream2); + + // FlightRecordBatchStream errors if the schema changes + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, FlightError> = decode_stream.try_collect().await; + + let err = result.unwrap_err(); + assert_eq!( + err.to_string(), + "ProtocolError(\"Unexpectedly saw multiple Schema messages in FlightData stream\")" + ); +} + /// Make a primtive batch for testing /// /// Example: From a6e61b1e10106508153d010ea0a86bb375f18afe Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 26 Dec 2022 09:08:14 -0600 Subject: [PATCH 07/16] Add new test --- arrow-flight/tests/encode_decode.rs | 54 ++++++++++++++++------------- arrow-ipc/src/reader.rs | 6 ++++ 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index fc16f1ec9b38..48e1dbc421bd 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -263,34 +263,40 @@ async fn test_chained_streams_data_decoder() { } #[tokio::test] -#[should_panic(expected = "assertion failed: idx < self.len()")] async fn test_mismatched_schema_message() { - let batch1 = make_primative_batch(5); - let batch2 = make_dictionary_batch(3); - // Model sending schema that is mismatched with the data + // and expect an error + async fn do_test(batch1: RecordBatch, batch2: RecordBatch, expected: &str) { + let encode_stream1 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch1.clone())])) + // take only schema message from first stream + .take(1); + let encode_stream2 = FlightDataEncoderBuilder::default() + .build(futures::stream::iter(vec![Ok(batch2.clone())])) + // take only data message from second + .skip(1); + + // append the two streams + let encode_stream = encode_stream1.chain(encode_stream2); + + // FlightRecordBatchStream errors if the schema changes + let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); + let result: Result, FlightError> = decode_stream.try_collect().await; - let encode_stream1 = FlightDataEncoderBuilder::default() - .build(futures::stream::iter(vec![Ok(batch1.clone())])) - // take only schema message from first stream - .take(1); - let encode_stream2 = FlightDataEncoderBuilder::default() - .build(futures::stream::iter(vec![Ok(batch2.clone())])) - // take only data message from second - .skip(1); - - // append the two streams - let encode_stream = encode_stream1.chain(encode_stream2); - - // FlightRecordBatchStream errors if the schema changes - let decode_stream = FlightRecordBatchStream::new_from_flight_data(encode_stream); - let result: Result, FlightError> = decode_stream.try_collect().await; + let err = result.unwrap_err().to_string(); + assert!( + err.contains(expected), + "could not find '{expected}' in '{err}'" + ); + } - let err = result.unwrap_err(); - assert_eq!( - err.to_string(), - "ProtocolError(\"Unexpectedly saw multiple Schema messages in FlightData stream\")" - ); + // primitive batch has more columns + do_test( + make_primative_batch(5), + make_dictionary_batch(3), + "Error decoding ipc RecordBatch: Io error: Invalid data for schema", + ) + .await; } /// Make a primtive batch for testing diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index ef0a49be693b..231f72910174 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -298,6 +298,12 @@ fn create_array( make_array(data) } _ => { + if nodes.len() <= node_index { + return Err(ArrowError::IoError(format!( + "Invalid data for schema. {} refers to node index {} but only {} in schema", + field, node_index, nodes.len() + ))); + } let array = create_primitive_array( nodes.get(node_index), data_type, From 2249871d5b511bcf362b1aca5e89850385cf884a Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 26 Dec 2022 09:09:42 -0600 Subject: [PATCH 08/16] more tests --- arrow-flight/tests/encode_decode.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 48e1dbc421bd..919d5c7ee043 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -290,13 +290,21 @@ async fn test_mismatched_schema_message() { ); } - // primitive batch has more columns + // primitive batch first (has more columns) do_test( make_primative_batch(5), make_dictionary_batch(3), "Error decoding ipc RecordBatch: Io error: Invalid data for schema", ) .await; + + // dictioanry batch first + do_test( + make_dictionary_batch(3), + make_primative_batch(5), + "Error decoding ipc RecordBatch: Invalid argument error", + ) + .await; } /// Make a primtive batch for testing From 4797c831241096e039cce54cba51e5d1d6ed44ff Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 29 Dec 2022 09:37:31 -0500 Subject: [PATCH 09/16] Apply suggestions from code review Co-authored-by: Liang-Chi Hsieh Co-authored-by: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> --- arrow-flight/src/decode.rs | 2 +- arrow-flight/src/encode.rs | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/arrow-flight/src/decode.rs b/arrow-flight/src/decode.rs index 0b5529c253f4..cab52a434897 100644 --- a/arrow-flight/src/decode.rs +++ b/arrow-flight/src/decode.rs @@ -192,7 +192,7 @@ pub struct FlightDataDecoder { response: BoxStream<'static, Result>, /// Decoding state state: Option, - /// seen the end of the inner stream? + /// Seen the end of the inner stream? done: bool, } diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 70147e606f99..f6017dcc4829 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -73,7 +73,7 @@ pub struct FlightDataEncoderBuilder { /// Default target size for record batches to send. /// /// Note this value would normally be 4MB, but the size calculation is -/// somehwhat inexact, so we set it to 2MB. +/// somewhat inexact, so we set it to 2MB. pub const GRPC_TARGET_MAX_BATCH_SIZE: usize = 2097152; impl Default for FlightDataEncoderBuilder { @@ -92,7 +92,7 @@ impl FlightDataEncoderBuilder { } /// Set the (approximate) maximum encoded [`RecordBatch`] size to - /// limit the gRPC message size. Defaults fo 2MB. + /// limit the gRPC message size. Defaults to 2MB. /// /// The encoder splits up [`RecordBatch`]s (preserving order) to /// limit individual messages to approximately this size. The size @@ -104,7 +104,7 @@ impl FlightDataEncoderBuilder { self } - /// Specfy application specific metadata included in the + /// Specify application specific metadata included in the /// [`FlightData::app_metadata`] field of the the first Schema /// message pub fn with_metadata(mut self, app_metadata: Bytes) -> Self { @@ -142,7 +142,7 @@ pub struct FlightDataEncoder { inner: BoxStream<'static, Result>, /// schema, set after the first batch schema: Option, - /// Max sixe of batches to encode + /// Max size of batches to encode max_batch_size: usize, /// do the encoding / tracking of dictionaries encoder: FlightIpcEncoder, @@ -178,7 +178,7 @@ impl FlightDataEncoder { } /// Place the `FlightData` in the queue to send - fn queue_messages(&mut self, datas: Vec) { + fn queue_messages(&mut self, datas: impl IntoIterator) { for data in datas { self.queue_message(data) } @@ -271,7 +271,7 @@ impl Stream for FlightDataEncoder { /// Convert dictionary types to underlying types /// /// See hydrate_dictionary for more information -pub fn prepare_schema_for_flight(schema: &Schema) -> Schema { +fn prepare_schema_for_flight(schema: &Schema) -> Schema { let fields = schema .fields() .iter() @@ -299,7 +299,7 @@ pub fn split_batch_for_grpc_response( let size = batch .columns() .iter() - .map(|col| col.get_array_memory_size()) + .map(|col| col.get_buffer_memory_size()) .sum::(); let n_batches = From c5402ff4ad282bcab1fc960294c0b5fca70b8af5 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 29 Dec 2022 08:43:27 -0600 Subject: [PATCH 10/16] Add From ArrowError impl for FlightError --- arrow-flight/src/encode.rs | 27 ++++++++++++++++----------- arrow-flight/src/error.rs | 6 ++++++ arrow-flight/tests/encode_decode.rs | 24 +++++++++++++++++++++--- 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index f6017dcc4829..b725d58ab1f5 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -17,7 +17,7 @@ use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; -use crate::{error::FlightError, error::Result, FlightData, SchemaAsIpc}; +use crate::{error::Result, FlightData, SchemaAsIpc}; use arrow_array::{ArrayRef, RecordBatch}; use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; @@ -178,7 +178,7 @@ impl FlightDataEncoder { } /// Place the `FlightData` in the queue to send - fn queue_messages(&mut self, datas: impl IntoIterator) { + fn queue_messages(&mut self, datas: impl IntoIterator) { for data in datas { self.queue_message(data) } @@ -292,6 +292,9 @@ fn prepare_schema_for_flight(schema: &Schema) -> Schema { /// Split [`RecordBatch`] so it hopefully fits into a gRPC response. /// /// Data is zero-copy sliced into batches. +/// +/// Note: this method does not take into account already sliced +/// arrays, pub fn split_batch_for_grpc_response( batch: RecordBatch, max_batch_size: usize, @@ -351,10 +354,11 @@ impl FlightIpcEncoder { &mut self, batch: &RecordBatch, ) -> Result<(Vec, FlightData)> { - let (encoded_dictionaries, encoded_batch) = self - .data_gen - .encoded_batch(batch, &mut self.dictionary_tracker, &self.options) - .map_err(FlightError::Arrow)?; + let (encoded_dictionaries, encoded_batch) = self.data_gen.encoded_batch( + batch, + &mut self.dictionary_tracker, + &self.options, + )?; let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); @@ -381,7 +385,7 @@ pub fn prepare_batch_for_flight( .map(hydrate_dictionary) .collect::>>()?; - RecordBatch::try_new(schema, columns).map_err(FlightError::Arrow) + Ok(RecordBatch::try_new(schema, columns)?) } /// Hydrates a dictionary to its underlying type @@ -398,11 +402,12 @@ pub fn prepare_batch_for_flight( /// /// For now we just hydrate the dictionaries to their underlying type fn hydrate_dictionary(array: &ArrayRef) -> Result { - if let DataType::Dictionary(_, value) = array.data_type() { - arrow_cast::cast(array, value).map_err(FlightError::Arrow) + let arr = if let DataType::Dictionary(_, value) = array.data_type() { + arrow_cast::cast(array, value)? } else { - Ok(Arc::clone(array)) - } + Arc::clone(array) + }; + Ok(arr) } #[cfg(test)] diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs index 7395c3362a83..11e0ae5c9fae 100644 --- a/arrow-flight/src/error.rs +++ b/arrow-flight/src/error.rs @@ -60,6 +60,12 @@ impl From for FlightError { } } +impl From for FlightError { + fn from(value: ArrowError) -> Self { + Self::Arrow(value) + } +} + // default conversion from FlightError to tonic treats everything // other than `Status` as an internal error impl From for tonic::Status { diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 919d5c7ee043..ea56b6fccf50 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -23,11 +23,10 @@ use arrow::{compute::concat_batches, datatypes::Int32Type}; use arrow_array::{ArrayRef, DictionaryArray, Float64Array, RecordBatch, UInt8Array}; use arrow_flight::{ decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream}, - encode::{ - prepare_batch_for_flight, prepare_schema_for_flight, FlightDataEncoderBuilder, - }, + encode::{prepare_batch_for_flight, FlightDataEncoderBuilder}, error::FlightError, }; +use arrow_schema::{DataType, Field, Schema}; use bytes::Bytes; use futures::{StreamExt, TryStreamExt}; @@ -406,3 +405,22 @@ async fn roundtrip_with_encoder( assert_eq!(expected_batches, output_batches); } + +/// Workaround for https://github.com/apache/arrow-rs/issues/1206 +fn prepare_schema_for_flight(schema: &Schema) -> Schema { + let fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) => Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + _ => field.clone(), + }) + .collect(); + + Schema::new(fields) +} From 694265df0000899c3d69f87b20b7dc3c7e288a8d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 29 Dec 2022 08:57:09 -0600 Subject: [PATCH 11/16] Correct make_dictionary_batch and add tests --- arrow-flight/tests/encode_decode.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index ea56b6fccf50..914a93f2afb6 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -336,10 +336,14 @@ fn make_primative_batch(num_rows: usize) -> RecordBatch { .unwrap() } +/// Make a dictionary batch for testing +/// +/// Example: +/// a: value0, value1, value2, None, value1, value2 fn make_dictionary_batch(num_rows: usize) -> RecordBatch { let values: Vec<_> = (0..num_rows) .map(|i| { - if i == i / 2 { + if i == num_rows / 2 { None } else { // repeat some values for low cardinality From ed1f85c4c4f30ca66a1cf24b795923147b176824 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 29 Dec 2022 09:02:17 -0600 Subject: [PATCH 12/16] do not take --- arrow-flight/src/encode.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index b725d58ab1f5..2f0b105b2801 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -186,8 +186,8 @@ impl FlightDataEncoder { /// Encodes batch into one or more `FlightData` messages in self.queue fn encode_batch(&mut self, batch: RecordBatch) -> Result<()> { - let schema = match self.schema.take() { - Some(schema) => schema, + let schema = match &self.schema { + Some(schema) => schema.clone(), None => { let batch_schema = batch.schema(); // The first message is the schema message, and all @@ -200,13 +200,12 @@ impl FlightDataEncoder { schema_flight_data.app_metadata = app_metadata; } self.queue_message(schema_flight_data); + // remember schema + self.schema = Some(schema.clone()); schema } }; - // remember schema - self.schema = Some(schema.clone()); - // encode the batch let batch = prepare_batch_for_flight(&batch, schema)?; From 2267d13a55b25717d26b07c891521e3003416903 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 29 Dec 2022 09:07:17 -0600 Subject: [PATCH 13/16] Make dictionary massaging non pub --- arrow-flight/src/encode.rs | 2 +- arrow-flight/tests/encode_decode.rs | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 2f0b105b2801..b7b72dcbe5dd 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -374,7 +374,7 @@ impl FlightIpcEncoder { /// 1. Hydrates any dictionaries to its underlying type. See /// hydrate_dictionary for more information. /// -pub fn prepare_batch_for_flight( +fn prepare_batch_for_flight( batch: &RecordBatch, schema: SchemaRef, ) -> Result { diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index 914a93f2afb6..45b8c0bf5ac9 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -23,10 +23,10 @@ use arrow::{compute::concat_batches, datatypes::Int32Type}; use arrow_array::{ArrayRef, DictionaryArray, Float64Array, RecordBatch, UInt8Array}; use arrow_flight::{ decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream}, - encode::{prepare_batch_for_flight, FlightDataEncoderBuilder}, + encode::FlightDataEncoderBuilder, error::FlightError, }; -use arrow_schema::{DataType, Field, Schema}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; use bytes::Bytes; use futures::{StreamExt, TryStreamExt}; @@ -428,3 +428,26 @@ fn prepare_schema_for_flight(schema: &Schema) -> Schema { Schema::new(fields) } + +/// Workaround for https://github.com/apache/arrow-rs/issues/1206 +fn prepare_batch_for_flight( + batch: &RecordBatch, + schema: SchemaRef, +) -> Result { + let columns = batch + .columns() + .iter() + .map(hydrate_dictionary) + .collect::, _>>()?; + + Ok(RecordBatch::try_new(schema, columns)?) +} + +fn hydrate_dictionary(array: &ArrayRef) -> Result { + let arr = if let DataType::Dictionary(_, value) = array.data_type() { + arrow_cast::cast(array, value)? + } else { + Arc::clone(array) + }; + Ok(arr) +} From a6ba713b249a3c06542ffe19023b8fcec7993b14 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 29 Dec 2022 09:13:32 -0600 Subject: [PATCH 14/16] Add comment about memory size and make split function non pub --- arrow-flight/src/encode.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index b7b72dcbe5dd..da8fde053610 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -293,8 +293,8 @@ fn prepare_schema_for_flight(schema: &Schema) -> Schema { /// Data is zero-copy sliced into batches. /// /// Note: this method does not take into account already sliced -/// arrays, -pub fn split_batch_for_grpc_response( +/// arrays: +fn split_batch_for_grpc_response( batch: RecordBatch, max_batch_size: usize, ) -> Vec { From 4f30ab71c36c44860b83346b0d7bd0672d84b095 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 29 Dec 2022 09:15:56 -0600 Subject: [PATCH 15/16] explicitly return early from encode stream --- arrow-flight/src/encode.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index da8fde053610..6cc4ff07c2af 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -245,6 +245,9 @@ impl Stream for FlightDataEncoder { None => { // inner is done self.done = true; + // queue must also be empty so we are done + assert!(self.queue.is_empty()); + return Poll::Ready(None); } Some(Err(e)) => { // error from inner From ba5e698b0d9189209c434a880e132aa2f1f74967 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 29 Dec 2022 09:41:46 -0600 Subject: [PATCH 16/16] fix doc link --- arrow-flight/src/encode.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 6cc4ff07c2af..7c339b67d488 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -60,6 +60,7 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt}; /// ``` /// /// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get +/// [`FlightError`]: crate::error::FlightError #[derive(Debug)] pub struct FlightDataEncoderBuilder { /// The maximum message size (see details on [`Self::with_max_message_size`]).