diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs index 840e8e8a93cc..1ca956e2c73f 100644 --- a/arrow-csv/src/writer.rs +++ b/arrow-csv/src/writer.rs @@ -70,11 +70,6 @@ use csv::ByteRecord; use std::io::Write; use crate::map_csv_error; - -const DEFAULT_DATE_FORMAT: &str = "%F"; -const DEFAULT_TIME_FORMAT: &str = "%T"; -const DEFAULT_TIMESTAMP_FORMAT: &str = "%FT%H:%M:%S.%9f"; -const DEFAULT_TIMESTAMP_TZ_FORMAT: &str = "%FT%H:%M:%S.%9f%:z"; const DEFAULT_NULL_VALUE: &str = ""; /// A CSV writer @@ -82,41 +77,29 @@ const DEFAULT_NULL_VALUE: &str = ""; pub struct Writer { /// The object to write to writer: csv::Writer, - /// Whether file should be written with headers. Defaults to `true` + /// Whether file should be written with headers, defaults to `true` has_headers: bool, - /// The date format for date arrays + /// The date format for date arrays, defaults to RFC3339 date_format: Option, - /// The datetime format for datetime arrays + /// The datetime format for datetime arrays, defaults to RFC3339 datetime_format: Option, - /// The timestamp format for timestamp arrays + /// The timestamp format for timestamp arrays, defaults to RFC3339 timestamp_format: Option, - /// The timestamp format for timestamp (with timezone) arrays + /// The timestamp format for timestamp (with timezone) arrays, defaults to RFC3339 timestamp_tz_format: Option, - /// The time format for time arrays + /// The time format for time arrays, defaults to RFC3339 time_format: Option, /// Is the beginning-of-writer beginning: bool, - /// The value to represent null entries - null_value: String, + /// The value to represent null entries, defaults to [`DEFAULT_NULL_VALUE`] + null_value: Option, } impl Writer { /// Create a new CsvWriter from a writable object, with default options pub fn new(writer: W) -> Self { let delimiter = b','; - let mut builder = csv::WriterBuilder::new(); - let writer = builder.delimiter(delimiter).from_writer(writer); - Writer { - writer, - has_headers: true, - date_format: Some(DEFAULT_DATE_FORMAT.to_string()), - datetime_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), - time_format: Some(DEFAULT_TIME_FORMAT.to_string()), - timestamp_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), - timestamp_tz_format: Some(DEFAULT_TIMESTAMP_TZ_FORMAT.to_string()), - beginning: true, - null_value: DEFAULT_NULL_VALUE.to_string(), - } + WriterBuilder::new().with_delimiter(delimiter).build(writer) } /// Write a vector of record batches to a writable object @@ -138,7 +121,7 @@ impl Writer { } let options = FormatOptions::default() - .with_null(&self.null_value) + .with_null(self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE)) .with_date_format(self.date_format.as_deref()) .with_datetime_format(self.datetime_format.as_deref()) .with_timestamp_format(self.timestamp_format.as_deref()) @@ -207,9 +190,9 @@ impl RecordBatchWriter for Writer { #[derive(Clone, Debug)] pub struct WriterBuilder { /// Optional column delimiter. Defaults to `b','` - delimiter: Option, + delimiter: u8, /// Whether to write column names as file headers. 
Defaults to `true` - has_headers: bool, + has_header: bool, /// Optional date format for date arrays date_format: Option<String>, /// Optional datetime format for datetime arrays @@ -227,14 +210,14 @@ pub struct WriterBuilder { impl Default for WriterBuilder { fn default() -> Self { Self { - has_headers: true, - delimiter: None, - date_format: Some(DEFAULT_DATE_FORMAT.to_string()), - datetime_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), - time_format: Some(DEFAULT_TIME_FORMAT.to_string()), - timestamp_format: Some(DEFAULT_TIMESTAMP_FORMAT.to_string()), - timestamp_tz_format: Some(DEFAULT_TIMESTAMP_TZ_FORMAT.to_string()), - null_value: Some(DEFAULT_NULL_VALUE.to_string()), + has_header: true, + delimiter: b',', + date_format: None, + datetime_format: None, + time_format: None, + timestamp_format: None, + timestamp_tz_format: None, + null_value: None, } } } @@ -254,7 +237,7 @@ impl WriterBuilder { /// let file = File::create("target/out.csv").unwrap(); /// /// // create a builder that doesn't write headers - /// let builder = WriterBuilder::new().has_headers(false); + /// let builder = WriterBuilder::new().with_header(false); /// let writer = builder.build(file); /// /// writer @@ -265,48 +248,92 @@ impl WriterBuilder { } /// Set whether to write headers + #[deprecated(note = "Use Self::with_header")] + #[doc(hidden)] pub fn has_headers(mut self, has_headers: bool) -> Self { - self.has_headers = has_headers; + self.has_header = has_headers; + self + } + + /// Set whether to write the CSV file with a header + pub fn with_header(mut self, header: bool) -> Self { + self.has_header = header; self } + /// Returns `true` if this writer is configured to write a header + pub fn header(&self) -> bool { + self.has_header + } + /// Set the CSV file's column delimiter as a byte character pub fn with_delimiter(mut self, delimiter: u8) -> Self { - self.delimiter = Some(delimiter); + self.delimiter = delimiter; self } + /// Get the CSV file's column delimiter as a byte character + pub fn delimiter(&self) -> u8 { + self.delimiter + } + /// Set the CSV file's date format pub fn with_date_format(mut self, format: String) -> Self { self.date_format = Some(format); self } + /// Get the CSV file's date format if set, defaults to RFC3339 + pub fn date_format(&self) -> Option<&str> { + self.date_format.as_deref() + } + /// Set the CSV file's datetime format pub fn with_datetime_format(mut self, format: String) -> Self { self.datetime_format = Some(format); self } + /// Get the CSV file's datetime format if set, defaults to RFC3339 + pub fn datetime_format(&self) -> Option<&str> { + self.datetime_format.as_deref() + } + /// Set the CSV file's time format pub fn with_time_format(mut self, format: String) -> Self { self.time_format = Some(format); self } + /// Get the CSV file's time format if set, defaults to RFC3339 + pub fn time_format(&self) -> Option<&str> { + self.time_format.as_deref() + } + /// Set the CSV file's timestamp format pub fn with_timestamp_format(mut self, format: String) -> Self { self.timestamp_format = Some(format); self } + /// Get the CSV file's timestamp format if set, defaults to RFC3339 + pub fn timestamp_format(&self) -> Option<&str> { + self.timestamp_format.as_deref() + } + /// Set the value to represent null in output pub fn with_null(mut self, null_value: String) -> Self { self.null_value = Some(null_value); self } - /// Use RFC3339 format for date/time/timestamps + /// Get the value to represent null in output + pub fn null(&self) -> &str {
self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE) + } + + /// Use RFC3339 format for date/time/timestamps (default) + #[deprecated(note = "Use WriterBuilder::default()")] pub fn with_rfc3339(mut self) -> Self { self.date_format = None; self.datetime_format = None; @@ -318,21 +345,18 @@ impl WriterBuilder { /// Create a new `Writer` pub fn build(self, writer: W) -> Writer { - let delimiter = self.delimiter.unwrap_or(b','); let mut builder = csv::WriterBuilder::new(); - let writer = builder.delimiter(delimiter).from_writer(writer); + let writer = builder.delimiter(self.delimiter).from_writer(writer); Writer { writer, - has_headers: self.has_headers, + beginning: true, + has_headers: self.has_header, date_format: self.date_format, datetime_format: self.datetime_format, time_format: self.time_format, timestamp_format: self.timestamp_format, timestamp_tz_format: self.timestamp_tz_format, - beginning: true, - null_value: self - .null_value - .unwrap_or_else(|| DEFAULT_NULL_VALUE.to_string()), + null_value: self.null_value, } } } @@ -411,11 +435,11 @@ mod tests { let expected = r#"c1,c2,c3,c4,c5,c6,c7 Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes -consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,cupcakes -sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes -consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,cupcakes -sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo "#; assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap()); } @@ -512,7 +536,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo let mut file = tempfile::tempfile().unwrap(); let builder = WriterBuilder::new() - .has_headers(false) + .with_header(false) .with_delimiter(b'|') .with_null("NULL".to_string()) .with_time_format("%r".to_string()); @@ -560,7 +584,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo ) .unwrap(); - let builder = WriterBuilder::new().has_headers(false); + let builder = WriterBuilder::new().with_header(false); let mut buf: Cursor> = Default::default(); // drop the writer early to release the borrow. 
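A usage sketch of the rewritten builder API above (the `with_header` setter and the getters are from this diff; the buffer and assertions are illustrative):

```rust
use arrow_csv::WriterBuilder;

fn main() {
    // Configure a writer: no header row, '|' delimiter, "NULL" for null values.
    let builder = WriterBuilder::new()
        .with_header(false)
        .with_delimiter(b'|')
        .with_null("NULL".to_string());

    // The new getters expose the effective configuration.
    assert!(!builder.header());
    assert_eq!(builder.delimiter(), b'|');
    assert_eq!(builder.null(), "NULL");
    // Temporal formats are now unset by default, which means RFC3339 output.
    assert_eq!(builder.timestamp_format(), None);

    let mut buf = Vec::new();
    let _writer = builder.build(&mut buf);
}
```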
@@ -652,7 +676,7 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo let mut file = tempfile::tempfile().unwrap(); - let builder = WriterBuilder::new().with_rfc3339(); + let builder = WriterBuilder::new(); let mut writer = builder.build(&mut file); let batches = vec![&batch]; for batch in batches { diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index edaa7129dc9a..70227eedea0e 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -50,7 +50,7 @@ tonic = { version = "0.10.0", default-features = false, features = ["transport", # CLI-related dependencies anyhow = { version = "1.0", optional = true } -clap = { version = "4.1", default-features = false, features = ["std", "derive", "env", "help", "error-context", "usage"], optional = true } +clap = { version = "4.4.6", default-features = false, features = ["std", "derive", "env", "help", "error-context", "usage", "wrap_help", "color", "suggestions"], optional = true } tracing-log = { version = "0.1", optional = true } tracing-subscriber = { version = "0.3.1", default-features = false, features = ["ansi", "env-filter", "fmt"], optional = true } diff --git a/arrow-flight/README.md b/arrow-flight/README.md index 9194b209fe72..b80772ac927e 100644 --- a/arrow-flight/README.md +++ b/arrow-flight/README.md @@ -44,5 +44,33 @@ that demonstrate how to build a Flight server implemented with [tonic](https://d ## Feature Flags - `flight-sql-experimental`: Enables experimental support for - [Apache Arrow FlightSQL](https://arrow.apache.org/docs/format/FlightSql.html), - a protocol for interacting with SQL databases. + [Apache Arrow FlightSQL], a protocol for interacting with SQL databases. + +## CLI + +This crate offers a basic [Apache Arrow FlightSQL] command line interface. + +The client can be installed from the repository: + +```console +$ cargo install --features=cli,flight-sql-experimental,tls --bin=flight_sql_client --path=. --locked +``` + +The client comes with extensive help text: + +```console +$ flight_sql_client help +``` + +A query can be executed using: + +```console +$ flight_sql_client --host example.com statement-query "SELECT 1;" ++----------+ +| Int64(1) | ++----------+ +| 1 | ++----------+ +``` + +[apache arrow flightsql]: https://arrow.apache.org/docs/format/FlightSql.html diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs index df51530b3c8f..296efc1c308e 100644 --- a/arrow-flight/src/bin/flight_sql_client.rs +++ b/arrow-flight/src/bin/flight_sql_client.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License.
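The `-H` and `-p` flags added below accept `KEY=value` pairs split on the first `=`; the `parse_key_val` helper further down in this diff reduces to the following self-contained sketch:

```rust
/// Parse a `KEY=value` pair, splitting on the first `=`.
fn parse_key_val(s: &str) -> Result<(String, String), String> {
    let pos = s
        .find('=')
        .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?;
    Ok((s[..pos].to_owned(), s[pos + 1..].to_owned()))
}

fn main() {
    assert_eq!(
        parse_key_val("foo=bar"),
        Ok(("foo".to_owned(), "bar".to_owned()))
    );
    // Only the first '=' separates the key; the rest stays in the value.
    assert_eq!(
        parse_key_val("a=b=c"),
        Ok(("a".to_owned(), "b=c".to_owned()))
    );
    assert!(parse_key_val("no-separator").is_err());
}
```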
-use std::{error::Error, sync::Arc, time::Duration}; +use std::{sync::Arc, time::Duration}; use anyhow::{bail, Context, Result}; use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray}; @@ -30,45 +30,17 @@ use tonic::{ }; use tracing_log::log::info; -/// A ':' separated key value pair -#[derive(Debug, Clone)] -struct KeyValue<K, V> { - pub key: K, - pub value: V, -} - -impl<K, V> std::str::FromStr for KeyValue<K, V> -where - K: std::str::FromStr, - V: std::str::FromStr, - K::Err: std::fmt::Display, - V::Err: std::fmt::Display, -{ - type Err = String; - - fn from_str(s: &str) -> std::result::Result<Self, Self::Err> { - let parts = s.splitn(2, ':').collect::<Vec<_>>(); - match parts.as_slice() { - [key, value] => { - let key = K::from_str(key).map_err(|e| e.to_string())?; - let value = V::from_str(value.trim()).map_err(|e| e.to_string())?; - Ok(Self { key, value }) - } - _ => Err(format!( - "Invalid key value pair - expected 'KEY:VALUE' got '{s}'" - )), - } - } -} - /// Logging CLI config. #[derive(Debug, Parser)] pub struct LoggingArgs { /// Log verbosity. /// - /// Use `-v for warn, `-vv for info, -vvv for debug, -vvvv for trace. + /// Defaults to "warn". /// - /// Note you can also set logging level using `RUST_LOG` environment variable: `RUST_LOG=debug` + /// Use `-v` for "info", `-vv` for "debug", `-vvv` for "trace". + /// + /// Note you can also set logging level using `RUST_LOG` environment variable: + /// `RUST_LOG=debug`. #[clap( short = 'v', long = "verbose", @@ -81,16 +53,22 @@ pub struct LoggingArgs { struct ClientArgs { /// Additional headers. /// - /// Values should be key value pairs separated by ':' - #[clap(long, value_delimiter = ',')] - headers: Vec<KeyValue<String, String>>, + /// Can be given multiple times. Headers and values are separated by '='. + /// + /// Example: `-H foo=bar -H baz=42` + #[clap(long = "header", short = 'H', value_parser = parse_key_val)] + headers: Vec<(String, String)>, - /// Username - #[clap(long)] + /// Username. + /// + /// Optional. If given, `password` must also be set. + #[clap(long, requires = "password")] username: Option<String>, - /// Password - #[clap(long)] + /// Password. + /// + /// Optional. If given, `username` must also be set. + #[clap(long, requires = "username")] password: Option<String>, /// Auth token. @@ -98,14 +76,20 @@ struct ClientArgs { token: Option<String>, /// Use TLS. + /// + /// If not provided, use cleartext connection. #[clap(long)] tls: bool, /// Server host. + /// + /// Required. #[clap(long)] host: String, /// Server port. + /// + /// Defaults to `443` if `tls` is set, otherwise defaults to `80`. #[clap(long)] port: Option<u16>, } @@ -124,13 +108,34 @@ struct Args { cmd: Command, } +/// Different available commands. #[derive(Debug, Subcommand)] enum Command { + /// Execute given statement. StatementQuery { + /// SQL query. + /// + /// Required. query: String, }, + + /// Prepare given statement and then execute it. PreparedStatementQuery { + /// SQL query. + /// + /// Required. + /// + /// Can contain placeholders like `$1`. + /// + /// Example: `SELECT * FROM t WHERE x = $1` query: String, + + /// Additional parameters. + /// + /// Can be given multiple times. Names and values are separated by '='. Values will be + /// converted to the type that the server reported for the prepared statement.
+ /// + /// Example: `-p $1=42` #[clap(short, value_parser = parse_key_val)] params: Vec<(String, String)>, }, @@ -284,8 +289,8 @@ async fn setup_client(args: ClientArgs) -> Result Result Result<(String, String), Box> { +fn parse_key_val(s: &str) -> Result<(String, String), String> { let pos = s .find('=') .ok_or_else(|| format!("invalid KEY=value: no `=` found in `{s}`"))?; - Ok((s[..pos].parse()?, s[pos + 1..].parse()?)) + Ok((s[..pos].to_owned(), s[pos + 1..].to_owned())) } /// Log headers/trailers. diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index cd2ee7c02b68..9ae7f1637982 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -30,10 +30,17 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt}; /// This can be used to implement [`FlightService::do_get`] in an /// Arrow Flight implementation; /// +/// This structure encodes a stream of `Result`s rather than `RecordBatch`es to +/// propagate errors from streaming execution, where the generation of the +/// `RecordBatch`es is incremental, and an error may occur even after +/// several have already been successfully produced. +/// /// # Caveats -/// 1. [`DictionaryArray`](arrow_array::array::DictionaryArray)s -/// are converted to their underlying types prior to transport, due to -/// . +/// 1. When [`DictionaryHandling`] is [`DictionaryHandling::Hydrate`], [`DictionaryArray`](arrow_array::array::DictionaryArray)s +/// are converted to their underlying types prior to transport. +/// When [`DictionaryHandling`] is [`DictionaryHandling::Resend`], Dictionary [`FlightData`] is sent with every +/// [`RecordBatch`] that contains a [`DictionaryArray`](arrow_array::array::DictionaryArray). +/// See . /// /// # Example /// ```no_run @@ -41,14 +48,14 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt}; /// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; /// # async fn f() { /// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); -/// # let record_batch = RecordBatch::try_from_iter(vec![ +/// # let batch = RecordBatch::try_from_iter(vec![ /// # ("a", Arc::new(c1) as ArrayRef) /// # ]) /// # .expect("cannot create record batch"); /// use arrow_flight::encode::FlightDataEncoderBuilder; /// /// // Get an input stream of Result -/// let input_stream = futures::stream::iter(vec![Ok(record_batch)]); +/// let input_stream = futures::stream::iter(vec![Ok(batch)]); /// /// // Build a stream of `Result` (e.g. to return for do_get) /// let flight_data_stream = FlightDataEncoderBuilder::new() @@ -59,6 +66,39 @@ use futures::{ready, stream::BoxStream, Stream, StreamExt}; /// # } /// ``` /// +/// # Example: Sending `Vec` +/// +/// You can create a [`Stream`] to pass to [`Self::build`] from an existing +/// `Vec` of `RecordBatch`es like this: +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +/// # async fn f() { +/// # fn make_batches() -> Vec { +/// # let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]); +/// # let batch = RecordBatch::try_from_iter(vec![ +/// # ("a", Arc::new(c1) as ArrayRef) +/// # ]) +/// # .expect("cannot create record batch"); +/// # vec![batch.clone(), batch.clone()] +/// # } +/// use arrow_flight::encode::FlightDataEncoderBuilder; +/// +/// // Get batches that you want to send via Flight +/// let batches: Vec = make_batches(); +/// +/// // Create an input stream of Result +/// let input_stream = futures::stream::iter( +/// batches.into_iter().map(Ok) +/// ); +/// +/// // Build a stream of `Result` (e.g. 
to return for do_get) +/// let flight_data_stream = FlightDataEncoderBuilder::new() +/// .build(input_stream); +/// # } +/// ``` +/// /// [`FlightService::do_get`]: crate::flight_service_server::FlightService::do_get /// [`FlightError`]: crate::error::FlightError #[derive(Debug)] @@ -74,6 +114,9 @@ pub struct FlightDataEncoderBuilder { schema: Option<SchemaRef>, /// Optional flight descriptor, if known before data. descriptor: Option<FlightDescriptor>, + /// Determines how `DictionaryArray`s are encoded for transport. + /// See [`DictionaryHandling`] for more information. + dictionary_handling: DictionaryHandling, } /// Default target size for encoded [`FlightData`]. @@ -90,6 +133,7 @@ impl Default for FlightDataEncoderBuilder { app_metadata: Bytes::new(), schema: None, descriptor: None, + dictionary_handling: DictionaryHandling::Hydrate, } } } @@ -114,6 +158,15 @@ impl FlightDataEncoderBuilder { self } + /// Set [`DictionaryHandling`] for encoder + pub fn with_dictionary_handling( + mut self, + dictionary_handling: DictionaryHandling, + ) -> Self { + self.dictionary_handling = dictionary_handling; + self + } + /// Specify application specific metadata included in the /// [`FlightData::app_metadata`] field of the first Schema /// message @@ -146,8 +199,10 @@ impl FlightDataEncoderBuilder { self } - /// Return a [`Stream`] of [`FlightData`], - /// consuming self. More details on [`FlightDataEncoder`] + /// Takes a [`Stream`] of [`Result<RecordBatch>`] and returns a [`Stream`] + /// of [`FlightData`], consuming self. + /// + /// See example on [`Self`] and [`FlightDataEncoder`] for more details pub fn build<S>(self, input: S) -> FlightDataEncoder where S: Stream<Item = Result<RecordBatch>> + Send + 'static, { @@ -158,6 +213,7 @@ impl FlightDataEncoderBuilder { app_metadata, schema, descriptor, + dictionary_handling, } = self; FlightDataEncoder::new( @@ -167,6 +223,7 @@ impl FlightDataEncoderBuilder { options, app_metadata, descriptor, + dictionary_handling, ) } } @@ -192,6 +249,9 @@ pub struct FlightDataEncoder { done: bool, /// cleared after the first FlightData message is sent descriptor: Option<FlightDescriptor>, + /// Determines how `DictionaryArray`s are encoded for transport. + /// See [`DictionaryHandling`] for more information.
+ dictionary_handling: DictionaryHandling, } impl FlightDataEncoder { @@ -202,16 +262,21 @@ impl FlightDataEncoder { options: IpcWriteOptions, app_metadata: Bytes, descriptor: Option, + dictionary_handling: DictionaryHandling, ) -> Self { let mut encoder = Self { inner, schema: None, max_flight_data_size, - encoder: FlightIpcEncoder::new(options), + encoder: FlightIpcEncoder::new( + options, + dictionary_handling != DictionaryHandling::Resend, + ), app_metadata: Some(app_metadata), queue: VecDeque::new(), done: false, descriptor, + dictionary_handling, }; // If schema is known up front, enqueue it immediately @@ -242,7 +307,8 @@ impl FlightDataEncoder { fn encode_schema(&mut self, schema: &SchemaRef) -> SchemaRef { // The first message is the schema message, and all // batches have the same schema - let schema = Arc::new(prepare_schema_for_flight(schema)); + let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend; + let schema = Arc::new(prepare_schema_for_flight(schema, send_dictionaries)); let mut schema_flight_data = self.encoder.encode_schema(&schema); // attach any metadata requested @@ -264,7 +330,8 @@ impl FlightDataEncoder { }; // encode the batch - let batch = prepare_batch_for_flight(&batch, schema)?; + let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend; + let batch = prepare_batch_for_flight(&batch, schema, send_dictionaries)?; for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) { let (flight_dictionaries, flight_batch) = @@ -325,17 +392,46 @@ impl Stream for FlightDataEncoder { } } +/// Defines how a [`FlightDataEncoder`] encodes [`DictionaryArray`]s +/// +/// [`DictionaryArray`]: arrow_array::DictionaryArray +#[derive(Debug, PartialEq)] +pub enum DictionaryHandling { + /// Expands to the underlying type (default). This likely sends more data + /// over the network but requires less memory (dictionaries are not tracked) + /// and is more compatible with other arrow flight client implementations + /// that may not support `DictionaryEncoding` + /// + /// An IPC response, streaming or otherwise, defines its schema up front + /// which defines the mapping from dictionary IDs. It then sends these + /// dictionaries over the wire. + /// + /// This requires identifying the different dictionaries in use, assigning + /// them IDs, and sending new dictionaries, delta or otherwise, when needed + /// + /// See also: + /// * + Hydrate, + /// Send dictionary FlightData with every RecordBatch that contains a + /// [`DictionaryArray`]. See [`Self::Hydrate`] for more tradeoffs. No + /// attempt is made to skip sending the same (logical) dictionary values + /// twice. 
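+    ///
+    /// For example (a usage sketch mirroring the `test_send_dictionaries` test
+    /// below; `input_stream` is assumed):
+    ///
+    /// ```ignore
+    /// let flight_data_stream = FlightDataEncoderBuilder::default()
+    ///     .with_dictionary_handling(DictionaryHandling::Resend)
+    ///     .build(input_stream);
+    /// ```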
+ /// + /// [`DictionaryArray`]: arrow_array::DictionaryArray + Resend, +} /// Prepare an arrow Schema for transport over the Arrow Flight protocol /// /// Convert dictionary types to underlying types /// /// See hydrate_dictionary for more information -fn prepare_schema_for_flight(schema: &Schema) -> Schema { +fn prepare_schema_for_flight(schema: &Schema, send_dictionaries: bool) -> Schema { let fields: Fields = schema .fields() .iter() .map(|field| match field.data_type() { - DataType::Dictionary(_, value_type) => Field::new( + DataType::Dictionary(_, value_type) if !send_dictionaries => Field::new( field.name(), value_type.as_ref().clone(), field.is_nullable(), @@ -394,8 +490,7 @@ struct FlightIpcEncoder { } impl FlightIpcEncoder { - fn new(options: IpcWriteOptions) -> Self { - let error_on_replacement = true; + fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self { Self { options, data_gen: IpcDataGenerator::default(), @@ -438,12 +533,14 @@ impl FlightIpcEncoder { fn prepare_batch_for_flight( batch: &RecordBatch, schema: SchemaRef, + send_dictionaries: bool, ) -> Result<RecordBatch> { let columns = batch .columns() .iter() - .map(hydrate_dictionary) + .map(|c| hydrate_dictionary(c, send_dictionaries)) .collect::<Result<Vec<_>>>()?; + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); Ok(RecordBatch::try_new_with_options( @@ -451,35 +548,28 @@ fn prepare_batch_for_flight( )?) } -/// Hydrates a dictionary to its underlying type -/// -/// An IPC response, streaming or otherwise, defines its schema up front -/// which defines the mapping from dictionary IDs. It then sends these -/// dictionaries over the wire. -/// -/// This requires identifying the different dictionaries in use, assigning -/// them IDs, and sending new dictionaries, delta or otherwise, when needed -/// -/// See also: -/// * -/// -/// For now we just hydrate the dictionaries to their underlying type -fn hydrate_dictionary(array: &ArrayRef) -> Result<ArrayRef> { - let arr = if let DataType::Dictionary(_, value) = array.data_type() { - arrow_cast::cast(array, value)? - } else { - Arc::clone(array) +/// Hydrates a dictionary to its underlying type if send_dictionaries is false. If send_dictionaries +/// is true, dictionaries are sent with every batch, which is not as optimal as described in [`DictionaryHandling::Hydrate`] above, +/// but does enable sending `DictionaryArray`s via Flight. +fn hydrate_dictionary(array: &ArrayRef, send_dictionaries: bool) -> Result<ArrayRef> { + let arr = match array.data_type() { + DataType::Dictionary(_, value) if !send_dictionaries => { + arrow_cast::cast(array, value)?
+ } + _ => Arc::clone(array), }; Ok(arr) } #[cfg(test)] mod tests { - use arrow_array::types::*; use arrow_array::*; + use arrow_array::{cast::downcast_array, types::*}; use arrow_cast::pretty::pretty_format_batches; use std::collections::HashMap; + use crate::decode::{DecodedPayload, FlightDataDecoder}; + use super::*; #[test] @@ -497,7 +587,7 @@ mod tests { let big_batch = batch.slice(0, batch.num_rows() - 1); let optimized_big_batch = - prepare_batch_for_flight(&big_batch, Arc::clone(&schema)) + prepare_batch_for_flight(&big_batch, Arc::clone(&schema), false) .expect("failed to optimize"); let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options); @@ -509,7 +599,7 @@ mod tests { let small_batch = batch.slice(0, 1); let optimized_small_batch = - prepare_batch_for_flight(&small_batch, Arc::clone(&schema)) + prepare_batch_for_flight(&small_batch, Arc::clone(&schema), false) .expect("failed to optimize"); let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options); @@ -520,6 +610,84 @@ mod tests { ); } + #[tokio::test] + async fn test_dictionary_hydration() { + let arr: DictionaryArray = vec!["a", "a", "b"].into_iter().collect(); + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(arr)]).unwrap(); + let encoder = FlightDataEncoderBuilder::default() + .build(futures::stream::once(async { Ok(batch) })); + let mut decoder = FlightDataDecoder::new(encoder); + let expected_schema = + Schema::new(vec![Field::new("dict", DataType::Utf8, false)]); + let expected_schema = Arc::new(expected_schema); + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, expected_schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), expected_schema); + let expected_array = StringArray::from(vec!["a", "a", "b"]); + let actual_array = b.column_by_name("dict").unwrap(); + let actual_array = downcast_array::(actual_array); + + assert_eq!(actual_array, expected_array); + } + } + } + } + + #[tokio::test] + async fn test_send_dictionaries() { + let schema = Arc::new(Schema::new(vec![Field::new_dictionary( + "dict", + DataType::UInt16, + DataType::Utf8, + false, + )])); + + let arr_one: Arc> = + Arc::new(vec!["a", "a", "b"].into_iter().collect()); + let arr_two: Arc> = + Arc::new(vec!["b", "a", "c"].into_iter().collect()); + let batch_one = + RecordBatch::try_new(schema.clone(), vec![arr_one.clone()]).unwrap(); + let batch_two = + RecordBatch::try_new(schema.clone(), vec![arr_two.clone()]).unwrap(); + + let encoder = FlightDataEncoderBuilder::default() + .with_dictionary_handling(DictionaryHandling::Resend) + .build(futures::stream::iter(vec![Ok(batch_one), Ok(batch_two)])); + + let mut decoder = FlightDataDecoder::new(encoder); + let mut expected_array = arr_one; + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match decoded.payload { + DecodedPayload::None => {} + DecodedPayload::Schema(s) => assert_eq!(s, schema), + DecodedPayload::RecordBatch(b) => { + assert_eq!(b.schema(), schema); + + let actual_array = + Arc::new(downcast_array::>( + b.column_by_name("dict").unwrap(), + )); + + assert_eq!(actual_array, expected_array); + + expected_array = arr_two.clone(); + } + } + } + } + #[test] fn test_schema_metadata_encoded() { let schema = @@ -527,7 
+695,7 @@ mod tests { HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), ); - let got = prepare_schema_for_flight(&schema); + let got = prepare_schema_for_flight(&schema, false); assert!(got.metadata().contains_key("some_key")); } @@ -540,7 +708,8 @@ mod tests { ) .expect("cannot create record batch"); - prepare_batch_for_flight(&batch, batch.schema()).expect("failed to optimize"); + prepare_batch_for_flight(&batch, batch.schema(), false) + .expect("failed to optimize"); } pub fn make_flight_data( diff --git a/arrow-ipc/Cargo.toml b/arrow-ipc/Cargo.toml index b5f66294a7c7..83ad044d25e7 100644 --- a/arrow-ipc/Cargo.toml +++ b/arrow-ipc/Cargo.toml @@ -41,7 +41,7 @@ arrow-data = { workspace = true } arrow-schema = { workspace = true } flatbuffers = { version = "23.1.21", default-features = false } lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"], optional = true } -zstd = { version = "0.12.0", default-features = false, optional = true } +zstd = { version = "0.13.0", default-features = false, optional = true } [features] default = [] diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index 4e98e2fd873a..c1cef0ec81b4 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -17,9 +17,13 @@ //! JSON reader //! -//! This JSON reader allows JSON line-delimited files to be read into the Arrow memory -//! model. Records are loaded in batches and are then converted from row-based data to -//! columnar data. +//! This JSON reader allows JSON records to be read into the Arrow memory +//! model. Records are loaded in batches and are then converted from the record-oriented +//! representation to the columnar arrow data model. +//! +//! The reader ignores whitespace between JSON values, including `\n` and `\r`, allowing +//! parsing of sequences of one or more arbitrarily formatted JSON values, including +//! but not limited to newline-delimited JSON. //! //! # Basic Usage //! @@ -130,6 +134,7 @@ //! 
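The relaxed whitespace handling described in the module docs above can be exercised like this (a sketch; the schema and data are illustrative):

```rust
use std::sync::Arc;
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_json::ReaderBuilder;
use arrow_schema::{DataType, Field, Schema};

fn main() {
    // Values separated by arbitrary whitespace, not just newlines.
    let data = "{\"a\": 1} {\"a\": 2}\n{\"a\": 3}";
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
    let mut reader = ReaderBuilder::new(schema).build(data.as_bytes()).unwrap();
    let batch = reader.next().unwrap().unwrap();
    let values = batch.column(0).as_primitive::<Int32Type>().values();
    assert_eq!(values, &[1, 2, 3]);
}
```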
use std::io::BufRead; +use std::sync::Arc; use chrono::Utc; use serde::Serialize; @@ -137,9 +142,11 @@ use serde::Serialize; use arrow_array::timezone::Tz; use arrow_array::types::Float32Type; use arrow_array::types::*; -use arrow_array::{downcast_integer, RecordBatch, RecordBatchReader, StructArray}; +use arrow_array::{ + downcast_integer, make_array, RecordBatch, RecordBatchReader, StructArray, +}; use arrow_data::ArrayData; -use arrow_schema::{ArrowError, DataType, SchemaRef, TimeUnit}; +use arrow_schema::{ArrowError, DataType, FieldRef, Schema, SchemaRef, TimeUnit}; pub use schema::*; use crate::reader::boolean_array::BooleanArrayDecoder; @@ -150,7 +157,7 @@ use crate::reader::null_array::NullArrayDecoder; use crate::reader::primitive_array::PrimitiveArrayDecoder; use crate::reader::string_array::StringArrayDecoder; use crate::reader::struct_array::StructArrayDecoder; -use crate::reader::tape::{Tape, TapeDecoder, TapeElement}; +use crate::reader::tape::{Tape, TapeDecoder}; use crate::reader::timestamp_array::TimestampArrayDecoder; mod boolean_array; @@ -171,6 +178,7 @@ pub struct ReaderBuilder { batch_size: usize, coerce_primitive: bool, strict_mode: bool, + is_field: bool, schema: SchemaRef, } @@ -189,10 +197,51 @@ impl ReaderBuilder { batch_size: 1024, coerce_primitive: false, strict_mode: false, + is_field: false, schema, } } + /// Create a new [`ReaderBuilder`] that will parse JSON values of `field.data_type()` + /// + /// Unlike [`ReaderBuilder::new`] this does not require the root of the JSON data + /// to be an object, i.e. `{..}`, allowing for parsing of any valid JSON value(s) + /// + /// ``` + /// # use std::sync::Arc; + /// # use arrow_array::cast::AsArray; + /// # use arrow_array::types::Int32Type; + /// # use arrow_json::ReaderBuilder; + /// # use arrow_schema::{DataType, Field}; + /// // Root of JSON schema is a numeric type + /// let data = "1\n2\n3\n"; + /// let field = Arc::new(Field::new("int", DataType::Int32, true)); + /// let mut reader = ReaderBuilder::new_with_field(field.clone()).build(data.as_bytes()).unwrap(); + /// let b = reader.next().unwrap().unwrap(); + /// let values = b.column(0).as_primitive::().values(); + /// assert_eq!(values, &[1, 2, 3]); + /// + /// // Root of JSON schema is a list type + /// let data = "[1, 2, 3, 4, 5, 6, 7]\n[1, 2, 3]"; + /// let field = Field::new_list("int", field.clone(), true); + /// let mut reader = ReaderBuilder::new_with_field(field).build(data.as_bytes()).unwrap(); + /// let b = reader.next().unwrap().unwrap(); + /// let list = b.column(0).as_list::(); + /// + /// assert_eq!(list.offsets().as_ref(), &[0, 7, 10]); + /// let list_values = list.values().as_primitive::(); + /// assert_eq!(list_values.values(), &[1, 2, 3, 4, 5, 6, 7, 1, 2, 3]); + /// ``` + pub fn new_with_field(field: impl Into) -> Self { + Self { + batch_size: 1024, + coerce_primitive: false, + strict_mode: false, + is_field: true, + schema: Arc::new(Schema::new([field.into()])), + } + } + /// Sets the batch size in rows to read pub fn with_batch_size(self, batch_size: usize) -> Self { Self { batch_size, ..self } @@ -233,16 +282,22 @@ impl ReaderBuilder { /// Create a [`Decoder`] pub fn build_decoder(self) -> Result { - let decoder = make_decoder( - DataType::Struct(self.schema.fields.clone()), - self.coerce_primitive, - self.strict_mode, - false, - )?; + let (data_type, nullable) = match self.is_field { + false => (DataType::Struct(self.schema.fields.clone()), false), + true => { + let field = &self.schema.fields[0]; + (field.data_type().clone(), 
field.is_nullable()) + } + }; + + let decoder = + make_decoder(data_type, self.coerce_primitive, self.strict_mode, nullable)?; + let num_fields = self.schema.all_fields().len(); Ok(Decoder { decoder, + is_field: self.is_field, tape_decoder: TapeDecoder::new(self.batch_size, num_fields), batch_size: self.batch_size, schema: self.schema, @@ -344,6 +399,7 @@ pub struct Decoder { tape_decoder: TapeDecoder, decoder: Box, batch_size: usize, + is_field: bool, schema: SchemaRef, } @@ -563,24 +619,20 @@ impl Decoder { let mut next_object = 1; let pos: Vec<_> = (0..tape.num_rows()) .map(|_| { - let end = match tape.get(next_object) { - TapeElement::StartObject(end) => end, - _ => unreachable!("corrupt tape"), - }; - std::mem::replace(&mut next_object, end + 1) + let next = tape.next(next_object, "row").unwrap(); + std::mem::replace(&mut next_object, next) }) .collect(); let decoded = self.decoder.decode(&tape, &pos)?; self.tape_decoder.clear(); - // Sanity check - assert!(matches!(decoded.data_type(), DataType::Struct(_))); - assert_eq!(decoded.null_count(), 0); - assert_eq!(decoded.len(), pos.len()); + let batch = match self.is_field { + true => RecordBatch::try_new(self.schema.clone(), vec![make_array(decoded)])?, + false => RecordBatch::from(StructArray::from(decoded)) + .with_schema(self.schema.clone())?, + }; - let batch = RecordBatch::from(StructArray::from(decoded)) - .with_schema(self.schema.clone())?; Ok(Some(batch)) } } @@ -2175,4 +2227,16 @@ mod tests { let values = batch.column(0).as_primitive::(); assert_eq!(values.values(), &[1681319393, -7200]); } + + #[test] + fn test_serde_field() { + let field = Field::new("int", DataType::Int32, true); + let mut decoder = ReaderBuilder::new_with_field(field) + .build_decoder() + .unwrap(); + decoder.serialize(&[1_i32, 2, 3, 4]).unwrap(); + let b = decoder.flush().unwrap().unwrap(); + let values = b.column(0).as_primitive::().values(); + assert_eq!(values, &[1, 2, 3, 4]); + } } diff --git a/arrow-json/src/reader/tape.rs b/arrow-json/src/reader/tape.rs index 801e8f29d525..b39caede7047 100644 --- a/arrow-json/src/reader/tape.rs +++ b/arrow-json/src/reader/tape.rs @@ -297,7 +297,8 @@ macro_rules! 
next { pub struct TapeDecoder { elements: Vec, - num_rows: usize, + /// The number of rows decoded, including any in progress if `!stack.is_empty()` + cur_row: usize, /// Number of rows to read per batch batch_size: usize, @@ -330,36 +331,34 @@ impl TapeDecoder { offsets, elements, batch_size, - num_rows: 0, + cur_row: 0, bytes: Vec::with_capacity(num_fields * 2 * 8), stack: Vec::with_capacity(10), } } pub fn decode(&mut self, buf: &[u8]) -> Result { - if self.num_rows >= self.batch_size { - return Ok(0); - } - let mut iter = BufIter::new(buf); while !iter.is_empty() { - match self.stack.last_mut() { - // Start of row + let state = match self.stack.last_mut() { + Some(l) => l, None => { - // Skip over leading whitespace iter.skip_whitespace(); - match next!(iter) { - b'{' => { - let idx = self.elements.len() as u32; - self.stack.push(DecoderState::Object(idx)); - self.elements.push(TapeElement::StartObject(u32::MAX)); - } - b => return Err(err(b, "trimming leading whitespace")), + if iter.is_empty() || self.cur_row >= self.batch_size { + break; } + + // Start of row + self.cur_row += 1; + self.stack.push(DecoderState::Value); + self.stack.last_mut().unwrap() } + }; + + match state { // Decoding an object - Some(DecoderState::Object(start_idx)) => { + DecoderState::Object(start_idx) => { iter.advance_until(|b| !json_whitespace(b) && b != b','); match next!(iter) { b'"' => { @@ -374,16 +373,12 @@ impl TapeDecoder { TapeElement::StartObject(end_idx); self.elements.push(TapeElement::EndObject(start_idx)); self.stack.pop(); - self.num_rows += self.stack.is_empty() as usize; - if self.num_rows >= self.batch_size { - break; - } } b => return Err(err(b, "parsing object")), } } // Decoding a list - Some(DecoderState::List(start_idx)) => { + DecoderState::List(start_idx) => { iter.advance_until(|b| !json_whitespace(b) && b != b','); match iter.peek() { Some(b']') => { @@ -400,7 +395,7 @@ impl TapeDecoder { } } // Decoding a string - Some(DecoderState::String) => { + DecoderState::String => { let s = iter.advance_until(|b| matches!(b, b'\\' | b'"')); self.bytes.extend_from_slice(s); @@ -415,7 +410,7 @@ impl TapeDecoder { b => unreachable!("{}", b), } } - Some(state @ DecoderState::Value) => { + state @ DecoderState::Value => { iter.skip_whitespace(); *state = match next!(iter) { b'"' => DecoderState::String, @@ -439,7 +434,7 @@ impl TapeDecoder { b => return Err(err(b, "parsing value")), }; } - Some(DecoderState::Number) => { + DecoderState::Number => { let s = iter.advance_until(|b| { !matches!(b, b'0'..=b'9' | b'-' | b'+' | b'.' | b'e' | b'E') }); @@ -452,14 +447,14 @@ impl TapeDecoder { self.offsets.push(self.bytes.len()); } } - Some(DecoderState::Colon) => { + DecoderState::Colon => { iter.skip_whitespace(); match next!(iter) { b':' => self.stack.pop(), b => return Err(err(b, "parsing colon")), }; } - Some(DecoderState::Literal(literal, idx)) => { + DecoderState::Literal(literal, idx) => { let bytes = literal.bytes(); let expected = bytes.iter().skip(*idx as usize).copied(); for (expected, b) in expected.zip(&mut iter) { @@ -474,7 +469,7 @@ impl TapeDecoder { self.elements.push(element); } } - Some(DecoderState::Escape) => { + DecoderState::Escape => { let v = match next!(iter) { b'u' => { self.stack.pop(); @@ -496,7 +491,7 @@ impl TapeDecoder { self.bytes.push(v); } // Parse a unicode escape sequence - Some(DecoderState::Unicode(high, low, idx)) => loop { + DecoderState::Unicode(high, low, idx) => loop { match *idx { 0..=3 => *high = *high << 4 | parse_hex(next!(iter))? 
as u16, 4 => { @@ -547,7 +542,7 @@ impl TapeDecoder { .try_for_each(|row| row.serialize(&mut serializer)) .map_err(|e| ArrowError::JsonError(e.to_string()))?; - self.num_rows += rows.len(); + self.cur_row += rows.len(); Ok(()) } @@ -591,7 +586,7 @@ impl TapeDecoder { strings, elements: &self.elements, string_offsets: &self.offsets, - num_rows: self.num_rows, + num_rows: self.cur_row, }) } @@ -599,7 +594,7 @@ impl TapeDecoder { pub fn clear(&mut self) { assert!(self.stack.is_empty()); - self.num_rows = 0; + self.cur_row = 0; self.bytes.clear(); self.elements.clear(); self.elements.push(TapeElement::Null); @@ -837,7 +832,7 @@ mod tests { let err = decoder.decode(b"hello").unwrap_err().to_string(); assert_eq!( err, - "Json error: Encountered unexpected 'h' whilst trimming leading whitespace" + "Json error: Encountered unexpected 'h' whilst parsing value" ); let mut decoder = TapeDecoder::new(16, 2); diff --git a/arrow-json/src/writer.rs b/arrow-json/src/writer.rs index db371b59080a..8c4145bc95b4 100644 --- a/arrow-json/src/writer.rs +++ b/arrow-json/src/writer.rs @@ -1338,11 +1338,7 @@ mod tests { let batch = reader.next().unwrap().unwrap(); - let list_row = batch - .column(0) - .as_any() - .downcast_ref::() - .unwrap(); + let list_row = batch.column(0).as_list::(); let values = list_row.values(); assert_eq!(values.len(), 4); assert_eq!(values.null_count(), 1); diff --git a/arrow-string/Cargo.toml b/arrow-string/Cargo.toml index e1163dc03eab..1ae7af8bdf41 100644 --- a/arrow-string/Cargo.toml +++ b/arrow-string/Cargo.toml @@ -40,5 +40,5 @@ arrow-schema = { workspace = true } arrow-array = { workspace = true } arrow-select = { workspace = true } regex = { version = "1.7.0", default-features = false, features = ["std", "unicode", "perf"] } -regex-syntax = { version = "0.7.1", default-features = false, features = ["unicode"] } +regex-syntax = { version = "0.8.0", default-features = false, features = ["unicode"] } num = { version = "0.4", default-features = false, features = ["std"] } diff --git a/arrow/tests/csv.rs b/arrow/tests/csv.rs index 3ee319101757..a79b6b44c2d3 100644 --- a/arrow/tests/csv.rs +++ b/arrow/tests/csv.rs @@ -53,48 +53,6 @@ fn test_export_csv_timestamps() { } drop(writer); - let left = "c1,c2 -2019-04-18T20:54:47.378000000+10:00,2019-04-18T10:54:47.378000000 -2021-10-30T17:59:07.000000000+11:00,2021-10-30T06:59:07.000000000\n"; - let right = String::from_utf8(sw).unwrap(); - assert_eq!(left, right); -} - -#[test] -fn test_export_csv_timestamps_using_rfc3339() { - let schema = Schema::new(vec![ - Field::new( - "c1", - DataType::Timestamp(TimeUnit::Millisecond, Some("Australia/Sydney".into())), - true, - ), - Field::new("c2", DataType::Timestamp(TimeUnit::Millisecond, None), true), - ]); - - let c1 = TimestampMillisecondArray::from( - // 1555584887 converts to 2019-04-18, 20:54:47 in time zone Australia/Sydney (AEST). - // The offset (difference to UTC) is +10:00. - // 1635577147 converts to 2021-10-30 17:59:07 in time zone Australia/Sydney (AEDT) - // The offset (difference to UTC) is +11:00. Note that daylight savings is in effect on 2021-10-30. 
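Earlier in this diff, the arrow-json test swaps a manual `downcast_ref::<ListArray>()` for the `AsArray` extension trait; a minimal self-contained sketch of that idiom (array contents illustrative):

```rust
use std::sync::Arc;
use arrow_array::cast::AsArray;
use arrow_array::types::Int32Type;
use arrow_array::{Array, ArrayRef, ListArray};

fn main() {
    // Build a small list array: [[1, 2], null]
    let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
        Some(vec![Some(1), Some(2)]),
        None,
    ]);
    let array: ArrayRef = Arc::new(list);

    // `as_list` panics on a type mismatch, just like the downcast it replaces.
    let list = array.as_list::<i32>();
    assert_eq!(list.len(), 2);
    assert_eq!(list.null_count(), 1);
}
```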
- // - vec![Some(1555584887378), Some(1635577147000)], - ) - .with_timezone("Australia/Sydney"); - let c2 = - TimestampMillisecondArray::from(vec![Some(1555584887378), Some(1635577147000)]); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); - - let mut sw = Vec::new(); - let mut writer = arrow_csv::WriterBuilder::new() - .with_rfc3339() - .build(&mut sw); - let batches = vec![&batch]; - for batch in batches { - writer.write(batch).unwrap(); - } - drop(writer); - let left = "c1,c2 2019-04-18T20:54:47.378+10:00,2019-04-18T10:54:47.378 2021-10-30T17:59:07+11:00,2021-10-30T06:59:07\n"; diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs index 1c35586f8bc9..8a45a9f3ac47 100644 --- a/object_store/src/aws/client.rs +++ b/object_store/src/aws/client.rs @@ -207,12 +207,13 @@ pub struct S3Config { pub retry_config: RetryConfig, pub client_options: ClientOptions, pub sign_payload: bool, + pub skip_signature: bool, pub checksum: Option, pub copy_if_not_exists: Option, } impl S3Config { - fn path_url(&self, path: &Path) -> String { + pub(crate) fn path_url(&self, path: &Path) -> String { format!("{}/{}", self.bucket_endpoint, encode_path(path)) } } @@ -234,8 +235,11 @@ impl S3Client { &self.config } - async fn get_credential(&self) -> Result> { - self.config.credentials.get_credential().await + async fn get_credential(&self) -> Result>> { + Ok(match self.config.skip_signature { + false => Some(self.config.credentials.get_credential().await?), + true => None, + }) } /// Make an S3 PUT request @@ -271,7 +275,7 @@ impl S3Client { let response = builder .query(query) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -299,7 +303,7 @@ impl S3Client { .request(Method::DELETE, url) .query(query) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -390,7 +394,7 @@ impl S3Client { .header(CONTENT_TYPE, "application/xml") .body(body) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -459,7 +463,7 @@ impl S3Client { builder .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -490,7 +494,7 @@ impl S3Client { .client .request(Method::POST, url) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -535,7 +539,7 @@ impl S3Client { .query(&[("uploadId", upload_id)]) .body(body) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -554,15 +558,10 @@ impl GetClient for S3Client { const STORE: &'static str = STORE; /// Make an S3 GET request - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { let credential = self.get_credential().await?; let url = self.config.path_url(path); - let method = match head { + let method = match options.head { true => Method::HEAD, false => Method::GET, }; @@ -572,7 +571,7 @@ impl GetClient for S3Client { let response = builder .with_get_options(options) .with_aws_sigv4( - credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, @@ -626,7 +625,7 @@ impl ListClient for S3Client { .request(Method::GET, &url) .query(&query) .with_aws_sigv4( - 
credential.as_ref(), + credential.as_deref(), &self.config.region, "s3", self.config.sign_payload, diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs index be0ffa578d13..e0c5de5fe784 100644 --- a/object_store/src/aws/credential.rs +++ b/object_store/src/aws/credential.rs @@ -30,7 +30,7 @@ use reqwest::{Client, Method, Request, RequestBuilder, StatusCode}; use serde::Deserialize; use std::collections::BTreeMap; use std::sync::Arc; -use std::time::Instant; +use std::time::{Duration, Instant}; use tracing::warn; use url::Url; @@ -89,6 +89,7 @@ const DATE_HEADER: &str = "x-amz-date"; const HASH_HEADER: &str = "x-amz-content-sha256"; const TOKEN_HEADER: &str = "x-amz-security-token"; const AUTH_HEADER: &str = "authorization"; +const ALGORITHM: &str = "AWS4-HMAC-SHA256"; impl<'a> AwsAuthorizer<'a> { /// Create a new [`AwsAuthorizer`] @@ -154,21 +155,110 @@ impl<'a> AwsAuthorizer<'a> { let header_digest = HeaderValue::from_str(&digest).unwrap(); request.headers_mut().insert(HASH_HEADER, header_digest); - // Each path segment must be URI-encoded twice (except for Amazon S3 which only gets URI-encoded once). + let (signed_headers, canonical_headers) = canonicalize_headers(request.headers()); + + let scope = self.scope(date); + + let string_to_sign = self.string_to_sign( + date, + &scope, + request.method(), + request.url(), + &canonical_headers, + &signed_headers, + &digest, + ); + + // sign the string + let signature = + self.credential + .sign(&string_to_sign, date, self.region, self.service); + + // build the actual auth header + let authorisation = format!( + "{} Credential={}/{}, SignedHeaders={}, Signature={}", + ALGORITHM, self.credential.key_id, scope, signed_headers, signature + ); + + let authorization_val = HeaderValue::from_str(&authorisation).unwrap(); + request.headers_mut().insert(AUTH_HEADER, authorization_val); + } + + pub(crate) fn sign(&self, method: Method, url: &mut Url, expires_in: Duration) { + let date = self.date.unwrap_or_else(Utc::now); + let scope = self.scope(date); + + // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + url.query_pairs_mut() + .append_pair("X-Amz-Algorithm", ALGORITHM) + .append_pair( + "X-Amz-Credential", + &format!("{}/{}", self.credential.key_id, scope), + ) + .append_pair("X-Amz-Date", &date.format("%Y%m%dT%H%M%SZ").to_string()) + .append_pair("X-Amz-Expires", &expires_in.as_secs().to_string()) + .append_pair("X-Amz-SignedHeaders", "host"); + + // For S3, you must include the X-Amz-Security-Token query parameter in the URL if + // using credentials sourced from the STS service. + if let Some(ref token) = self.credential.token { + url.query_pairs_mut() + .append_pair("X-Amz-Security-Token", token); + } + + // We don't have a payload; the user is going to send the payload directly themselves. 
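+        // SigV4 marks this with the literal sentinel "UNSIGNED-PAYLOAD"
+        // (the `UNSIGNED_PAYLOAD` constant below): the content hash is left
+        // out of the signature because the body is unknown at signing time.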
+ let digest = UNSIGNED_PAYLOAD; + + let host = &url[url::Position::BeforeHost..url::Position::AfterPort].to_string(); + let mut headers = HeaderMap::new(); + let host_val = HeaderValue::from_str(host).unwrap(); + headers.insert("host", host_val); + + let (signed_headers, canonical_headers) = canonicalize_headers(&headers); + + let string_to_sign = self.string_to_sign( + date, + &scope, + &method, + url, + &canonical_headers, + &signed_headers, + digest, + ); + + let signature = + self.credential + .sign(&string_to_sign, date, self.region, self.service); + + url.query_pairs_mut() + .append_pair("X-Amz-Signature", &signature); + } + + #[allow(clippy::too_many_arguments)] + fn string_to_sign( + &self, + date: DateTime, + scope: &str, + request_method: &Method, + url: &Url, + canonical_headers: &str, + signed_headers: &str, + digest: &str, + ) -> String { + // Each path segment must be URI-encoded twice (except for Amazon S3 which only gets + // URI-encoded once). // see https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html let canonical_uri = match self.service { - "s3" => request.url().path().to_string(), - _ => utf8_percent_encode(request.url().path(), &STRICT_PATH_ENCODE_SET) - .to_string(), + "s3" => url.path().to_string(), + _ => utf8_percent_encode(url.path(), &STRICT_PATH_ENCODE_SET).to_string(), }; - let (signed_headers, canonical_headers) = canonicalize_headers(request.headers()); - let canonical_query = canonicalize_query(request.url()); + let canonical_query = canonicalize_query(url); // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html let canonical_request = format!( "{}\n{}\n{}\n{}\n{}\n{}", - request.method().as_str(), + request_method.as_str(), canonical_uri, canonical_query, canonical_headers, @@ -177,33 +267,23 @@ impl<'a> AwsAuthorizer<'a> { ); let hashed_canonical_request = hex_digest(canonical_request.as_bytes()); - let scope = format!( - "{}/{}/{}/aws4_request", - date.format("%Y%m%d"), - self.region, - self.service - ); - let string_to_sign = format!( - "AWS4-HMAC-SHA256\n{}\n{}\n{}", + format!( + "{}\n{}\n{}\n{}", + ALGORITHM, date.format("%Y%m%dT%H%M%SZ"), scope, hashed_canonical_request - ); - - // sign the string - let signature = - self.credential - .sign(&string_to_sign, date, self.region, self.service); - - // build the actual auth header - let authorisation = format!( - "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}", - self.credential.key_id, scope, signed_headers, signature - ); + ) + } - let authorization_val = HeaderValue::from_str(&authorisation).unwrap(); - request.headers_mut().insert(AUTH_HEADER, authorization_val); + fn scope(&self, date: DateTime) -> String { + format!( + "{}/{}/{}/aws4_request", + date.format("%Y%m%d"), + self.region, + self.service + ) } } @@ -211,7 +291,7 @@ pub trait CredentialExt { /// Sign a request fn with_aws_sigv4( self, - credential: &AwsCredential, + credential: Option<&AwsCredential>, region: &str, service: &str, sign_payload: bool, @@ -222,20 +302,25 @@ pub trait CredentialExt { impl CredentialExt for RequestBuilder { fn with_aws_sigv4( self, - credential: &AwsCredential, + credential: Option<&AwsCredential>, region: &str, service: &str, sign_payload: bool, payload_sha256: Option<&[u8]>, ) -> Self { - let (client, request) = self.build_split(); - let mut request = request.expect("request valid"); + match credential { + Some(credential) => { + let (client, request) = self.build_split(); + let mut request = request.expect("request valid"); - 
AwsAuthorizer::new(credential, service, region) - .with_sign_payload(sign_payload) - .authorize(&mut request, payload_sha256); + AwsAuthorizer::new(credential, service, region) + .with_sign_payload(sign_payload) + .authorize(&mut request, payload_sha256); - Self::from_parts(client, request) + Self::from_parts(client, request) + } + None => self, + } } } @@ -667,7 +752,46 @@ mod tests { }; authorizer.authorize(&mut request, None); - assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699") + assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=653c3d8ea261fd826207df58bc2bb69fbb5003e9eb3c0ef06e4a51f2a81d8699"); + } + + #[test] + fn signed_get_url() { + // Values from https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html + let credential = AwsCredential { + key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + token: None, + }; + + let date = DateTime::parse_from_rfc3339("2013-05-24T00:00:00Z") + .unwrap() + .with_timezone(&Utc); + + let authorizer = AwsAuthorizer { + date: Some(date), + credential: &credential, + service: "s3", + region: "us-east-1", + sign_payload: false, + }; + + let mut url = + Url::parse("https://examplebucket.s3.amazonaws.com/test.txt").unwrap(); + authorizer.sign(Method::GET, &mut url, Duration::from_secs(86400)); + + assert_eq!( + url, + Url::parse( + "https://examplebucket.s3.amazonaws.com/test.txt?\ + X-Amz-Algorithm=AWS4-HMAC-SHA256&\ + X-Amz-Credential=AKIAIOSFODNN7EXAMPLE%2F20130524%2Fus-east-1%2Fs3%2Faws4_request&\ + X-Amz-Date=20130524T000000Z&\ + X-Amz-Expires=86400&\ + X-Amz-SignedHeaders=host&\ + X-Amz-Signature=aeeed9bbccd4d02ee5c0109b86d86835f995330da4c265957d157751f604d404" + ).unwrap() + ); } #[test] diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs index db3e1b9a4bbe..3ddce08002c4 100644 --- a/object_store/src/aws/mod.rs +++ b/object_store/src/aws/mod.rs @@ -36,10 +36,10 @@ use bytes::Bytes; use futures::stream::BoxStream; use futures::{StreamExt, TryStreamExt}; use itertools::Itertools; +use reqwest::Method; use serde::{Deserialize, Serialize}; use snafu::{ensure, OptionExt, ResultExt, Snafu}; -use std::str::FromStr; -use std::sync::Arc; +use std::{str::FromStr, sync::Arc, time::Duration}; use tokio::io::AsyncWrite; use tracing::info; use url::Url; @@ -56,6 +56,7 @@ use crate::client::{ }; use crate::config::ConfigValue; use crate::multipart::{PartId, PutPart, WriteMultiPart}; +use crate::signer::Signer; use crate::{ ClientOptions, GetOptions, GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, Result, RetryConfig, @@ -209,6 +210,65 @@ impl AmazonS3 { pub fn credentials(&self) -> &AwsCredentialProvider { &self.client.config().credentials } + + /// Create a full URL to the resource specified by `path` with this instance's configuration. + fn path_url(&self, path: &Path) -> String { + self.client.config().path_url(path) + } +} + +#[async_trait] +impl Signer for AmazonS3 { + /// Create a URL containing the relevant [AWS SigV4] query parameters that authorize a request + /// via `method` to the resource at `path` valid for the duration specified in `expires_in`. 
+ /// + /// [AWS SigV4]: https://docs.aws.amazon.com/IAM/latest/UserGuide/create-signed-request.html + /// + /// # Example + /// + /// This example returns a URL that will enable a user to upload a file to + /// "some-folder/some-file.txt" in the next hour. + /// + /// ``` + /// # async fn example() -> Result<(), Box> { + /// # use object_store::{aws::AmazonS3Builder, path::Path, signer::Signer}; + /// # use reqwest::Method; + /// # use std::time::Duration; + /// # + /// let region = "us-east-1"; + /// let s3 = AmazonS3Builder::new() + /// .with_region(region) + /// .with_bucket_name("my-bucket") + /// .with_access_key_id("my-access-key-id") + /// .with_secret_access_key("my-secret-access-key") + /// .build()?; + /// + /// let url = s3.signed_url( + /// Method::PUT, + /// &Path::from("some-folder/some-file.txt"), + /// Duration::from_secs(60 * 60) + /// ).await?; + /// # Ok(()) + /// # } + /// ``` + async fn signed_url( + &self, + method: Method, + path: &Path, + expires_in: Duration, + ) -> Result { + let credential = self.credentials().get_credential().await?; + let authorizer = + AwsAuthorizer::new(&credential, "s3", &self.client.config().region); + + let path_url = self.path_url(path); + let mut url = + Url::parse(&path_url).context(UnableToParseUrlSnafu { url: path_url })?; + + authorizer.sign(method, &mut url, expires_in); + + Ok(url) + } } #[async_trait] @@ -247,10 +307,6 @@ impl ObjectStore for AmazonS3 { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete_request(location, &()).await } @@ -392,6 +448,8 @@ pub struct AmazonS3Builder { client_options: ClientOptions, /// Credentials credentials: Option, + /// Skip signing requests + skip_signature: ConfigValue, /// Copy if not exists copy_if_not_exists: Option>, } @@ -530,6 +588,9 @@ pub enum AmazonS3ConfigKey { /// See [`S3CopyIfNotExists`] CopyIfNotExists, + /// Skip signing request + SkipSignature, + /// Client options Client(ClientConfigKey), } @@ -552,6 +613,7 @@ impl AsRef for AmazonS3ConfigKey { Self::ContainerCredentialsRelativeUri => { "aws_container_credentials_relative_uri" } + Self::SkipSignature => "aws_skip_signature", Self::CopyIfNotExists => "copy_if_not_exists", Self::Client(opt) => opt.as_ref(), } @@ -586,6 +648,7 @@ impl FromStr for AmazonS3ConfigKey { "aws_container_credentials_relative_uri" => { Ok(Self::ContainerCredentialsRelativeUri) } + "aws_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), "copy_if_not_exists" => Ok(Self::CopyIfNotExists), // Backwards compatibility "aws_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), @@ -697,6 +760,7 @@ impl AmazonS3Builder { AmazonS3ConfigKey::Client(key) => { self.client_options = self.client_options.with_config(key, value) } + AmazonS3ConfigKey::SkipSignature => self.skip_signature.parse(value), AmazonS3ConfigKey::CopyIfNotExists => { self.copy_if_not_exists = Some(ConfigValue::Deferred(value.into())) } @@ -767,6 +831,7 @@ impl AmazonS3Builder { AmazonS3ConfigKey::ContainerCredentialsRelativeUri => { self.container_credentials_relative_uri.clone() } + AmazonS3ConfigKey::SkipSignature => Some(self.skip_signature.to_string()), AmazonS3ConfigKey::CopyIfNotExists => { self.copy_if_not_exists.as_ref().map(ToString::to_string) } @@ -921,6 +986,14 @@ impl AmazonS3Builder { self } + /// If enabled, [`AmazonS3`] will not fetch credentials and will not sign requests + /// + /// This can be 
useful when interacting with public S3 buckets that deny authorized requests + pub fn with_skip_signature(mut self, skip_signature: bool) -> Self { + self.skip_signature = skip_signature.into(); + self + } + /// Sets the [checksum algorithm] which has to be used for object integrity check during upload. /// /// [checksum algorithm]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/checking-object-integrity.html @@ -1057,8 +1130,7 @@ impl AmazonS3Builder { Arc::new(TokenCredentialProvider::new( token, - // The instance metadata endpoint is access over HTTP - self.client_options.clone().with_allow_http(true).client()?, + self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ }; @@ -1090,6 +1162,7 @@ impl AmazonS3Builder { retry_config: self.retry_config, client_options: self.client_options, sign_payload: !self.unsigned_payload.get()?, + skip_signature: self.skip_signature.get()?, checksum, copy_if_not_exists, }; @@ -1449,4 +1522,30 @@ mod s3_resolve_bucket_region_tests { assert!(result.is_err()); } + + #[tokio::test] + #[ignore = "Tests shouldn't use remote services by default"] + async fn test_disable_creds() { + // https://registry.opendata.aws/daylight-osm/ + let v1 = AmazonS3Builder::new() + .with_bucket_name("daylight-map-distribution") + .with_region("us-west-1") + .with_access_key_id("local") + .with_secret_access_key("development") + .build() + .unwrap(); + + let prefix = Path::from("release"); + + v1.list_with_delimiter(Some(&prefix)).await.unwrap_err(); + + let v2 = AmazonS3Builder::new() + .with_bucket_name("daylight-map-distribution") + .with_region("us-west-1") + .with_skip_signature(true) + .build() + .unwrap(); + + v2.list_with_delimiter(Some(&prefix)).await.unwrap(); + } } diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs index cd1a3a10fcc7..f65388b61a80 100644 --- a/object_store/src/azure/client.rs +++ b/object_store/src/azure/client.rs @@ -264,15 +264,10 @@ impl GetClient for AzureClient { /// Make an Azure GET request /// /// - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { let credential = self.get_credential().await?; let url = self.config.path_url(path); - let method = match head { + let method = match options.head { true => Method::HEAD, false => Method::GET, }; diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs index b210d486d9bf..190b73bf9490 100644 --- a/object_store/src/azure/mod.rs +++ b/object_store/src/azure/mod.rs @@ -202,10 +202,6 @@ impl ObjectStore for MicrosoftAzure { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete(location).await } @@ -1074,7 +1070,7 @@ impl MicrosoftAzureBuilder { ); Arc::new(TokenCredentialProvider::new( msi_credential, - self.client_options.clone().with_allow_http(true).client()?, + self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ }; diff --git a/object_store/src/client/get.rs b/object_store/src/client/get.rs index 333f6fe58475..7f68b6d1225f 100644 --- a/object_store/src/client/get.rs +++ b/object_store/src/client/get.rs @@ -17,7 +17,7 @@ use crate::client::header::{header_meta, HeaderConfig}; use crate::path::Path; -use crate::{Error, GetOptions, GetResult, ObjectMeta}; +use crate::{Error, GetOptions, GetResult}; use 
crate::{GetResultPayload, Result}; use async_trait::async_trait; use futures::{StreamExt, TryStreamExt}; @@ -34,27 +34,20 @@ pub trait GetClient: Send + Sync + 'static { last_modified_required: true, }; - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result; + async fn get_request(&self, path: &Path, options: GetOptions) -> Result; } /// Extension trait for [`GetClient`] that adds common retrieval functionality #[async_trait] pub trait GetClientExt { async fn get_opts(&self, location: &Path, options: GetOptions) -> Result; - - async fn head(&self, location: &Path) -> Result; } #[async_trait] impl GetClientExt for T { async fn get_opts(&self, location: &Path, options: GetOptions) -> Result { let range = options.range.clone(); - let response = self.get_request(location, options, false).await?; + let response = self.get_request(location, options).await?; let meta = header_meta(location, response.headers(), T::HEADER_CONFIG).map_err(|e| { Error::Generic { @@ -77,15 +70,4 @@ impl GetClientExt for T { meta, }) } - - async fn head(&self, location: &Path) -> Result { - let options = GetOptions::default(); - let response = self.get_request(location, options, true).await?; - header_meta(location, response.headers(), T::HEADER_CONFIG).map_err(|e| { - Error::Generic { - store: T::STORE, - source: Box::new(e), - } - }) - } } diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs index ee9d62a44f0c..137da2b37594 100644 --- a/object_store/src/client/mod.rs +++ b/object_store/src/client/mod.rs @@ -166,7 +166,7 @@ impl FromStr for ClientConfigKey { } /// HTTP client configuration for remote object stores -#[derive(Debug, Clone, Default)] +#[derive(Debug, Clone)] pub struct ClientOptions { user_agent: Option>, content_type_map: HashMap, @@ -188,6 +188,35 @@ pub struct ClientOptions { http2_only: ConfigValue, } +impl Default for ClientOptions { + fn default() -> Self { + // Defaults based on + // + // + // Which recommend a connection timeout of 3.1s and a request timeout of 2s + Self { + user_agent: None, + content_type_map: Default::default(), + default_content_type: None, + default_headers: None, + proxy_url: None, + proxy_ca_certificate: None, + proxy_excludes: None, + allow_http: Default::default(), + allow_insecure: Default::default(), + timeout: Some(Duration::from_secs(5).into()), + connect_timeout: Some(Duration::from_secs(5).into()), + pool_idle_timeout: None, + pool_max_idle_per_host: None, + http2_keep_alive_interval: None, + http2_keep_alive_timeout: None, + http2_keep_alive_while_idle: Default::default(), + http1_only: Default::default(), + http2_only: Default::default(), + } + } +} + impl ClientOptions { /// Create a new [`ClientOptions`] with default values pub fn new() -> Self { @@ -367,17 +396,37 @@ impl ClientOptions { /// /// The timeout is applied from when the request starts connecting until the /// response body has finished + /// + /// Default is 5 seconds pub fn with_timeout(mut self, timeout: Duration) -> Self { self.timeout = Some(ConfigValue::Parsed(timeout)); self } + /// Disables the request timeout + /// + /// See [`Self::with_timeout`] + pub fn with_timeout_disabled(mut self) -> Self { + self.timeout = None; + self + } + /// Set a timeout for only the connect phase of a Client + /// + /// Default is 5 seconds pub fn with_connect_timeout(mut self, timeout: Duration) -> Self { self.connect_timeout = Some(ConfigValue::Parsed(timeout)); self } + /// Disables the connection timeout + /// + /// See 
[`Self::with_connect_timeout`] + pub fn with_connect_timeout_disabled(mut self) -> Self { + self.connect_timeout = None; + self + } + /// Set the pool max idle timeout /// /// This is the length of time an idle connection will be kept alive @@ -444,7 +493,20 @@ impl ClientOptions { } } - pub(crate) fn client(&self) -> super::Result { + /// Create a [`Client`] with overrides optimised for metadata endpoint access + /// + /// In particular: + /// * Allows HTTP as metadata endpoints do not use TLS + /// * Configures a low connection timeout to provide quick feedback if not present + #[cfg(any(feature = "aws", feature = "gcp", feature = "azure"))] + pub(crate) fn metadata_client(&self) -> Result { + self.clone() + .with_allow_http(true) + .with_connect_timeout(Duration::from_secs(1)) + .client() + } + + pub(crate) fn client(&self) -> Result { let mut builder = ClientBuilder::new(); match &self.user_agent { diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs index 39a913142e09..e4d246c87a2a 100644 --- a/object_store/src/client/retry.rs +++ b/object_store/src/client/retry.rs @@ -23,46 +23,50 @@ use futures::FutureExt; use reqwest::header::LOCATION; use reqwest::{Response, StatusCode}; use snafu::Error as SnafuError; +use snafu::Snafu; use std::time::{Duration, Instant}; use tracing::info; /// Retry request error -#[derive(Debug)] -pub struct Error { - retries: usize, - message: String, - source: Option, - status: Option, -} - -impl std::fmt::Display for Error { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "response error \"{}\", after {} retries", - self.message, self.retries )?; - if let Some(source) = &self.source { - write!(f, ": {source}")?; - } - Ok(()) - } -} - -impl std::error::Error for Error { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.source.as_ref().map(|e| e as _) - } +#[derive(Debug, Snafu)] +pub enum Error { + #[snafu(display("Received redirect without LOCATION, this normally indicates an incorrectly configured region"))] + BareRedirect, + + #[snafu(display("Client error with status {status}: {}", body.as_deref().unwrap_or("No Body")))] + Client { + status: StatusCode, + body: Option, + }, + + #[snafu(display("Error after {retries} retries: {source}"))] + Reqwest { + retries: usize, + source: reqwest::Error, + }, } impl Error { /// Returns the status code associated with this error if any pub fn status(&self) -> Option { - self.status + match self { + Self::BareRedirect => None, + Self::Client { status, .. } => Some(*status), + Self::Reqwest { source, .. } => source.status(), + } + } + + /// Returns the error body if any + pub fn body(&self) -> Option<&str> { + match self { + Self::Client { body, .. } => body.as_deref(), + Self::BareRedirect => None, + Self::Reqwest { .. 
} => None, + } } pub fn error(self, store: &'static str, path: String) -> crate::Error { - match self.status { + match self.status() { Some(StatusCode::NOT_FOUND) => crate::Error::NotFound { path, source: Box::new(self), @@ -86,16 +90,19 @@ impl Error { impl From for std::io::Error { fn from(err: Error) -> Self { use std::io::ErrorKind; - match (&err.source, err.status()) { - (Some(source), _) if source.is_builder() || source.is_request() => { - Self::new(ErrorKind::InvalidInput, err) - } - (_, Some(StatusCode::NOT_FOUND)) => Self::new(ErrorKind::NotFound, err), - (_, Some(StatusCode::BAD_REQUEST)) => Self::new(ErrorKind::InvalidInput, err), - (Some(source), None) if source.is_timeout() => { + match &err { + Error::Client { + status: StatusCode::NOT_FOUND, + .. + } => Self::new(ErrorKind::NotFound, err), + Error::Client { + status: StatusCode::BAD_REQUEST, + .. + } => Self::new(ErrorKind::InvalidInput, err), + Error::Reqwest { source, .. } if source.is_timeout() => { Self::new(ErrorKind::TimedOut, err) } - (Some(source), None) if source.is_connect() => { + Error::Reqwest { source, .. } if source.is_connect() => { Self::new(ErrorKind::NotConnected, err) } _ => Self::new(ErrorKind::Other, err), @@ -169,27 +176,21 @@ impl RetryExt for reqwest::RequestBuilder { Ok(r) => match r.error_for_status_ref() { Ok(_) if r.status().is_success() => return Ok(r), Ok(r) if r.status() == StatusCode::NOT_MODIFIED => { - return Err(Error{ - message: "not modified".to_string(), - retries, - status: Some(r.status()), - source: None, + return Err(Error::Client { + body: None, + status: StatusCode::NOT_MODIFIED, }) } Ok(r) => { let is_bare_redirect = r.status().is_redirection() && !r.headers().contains_key(LOCATION); - let message = match is_bare_redirect { - true => "Received redirect without LOCATION, this normally indicates an incorrectly configured region".to_string(), + return match is_bare_redirect { + true => Err(Error::BareRedirect), // Not actually sure if this is reachable, but here for completeness - false => format!("request unsuccessful: {}", r.status()), - }; - - return Err(Error{ - message, - retries, - status: Some(r.status()), - source: None, - }) + false => Err(Error::Client { + body: None, + status: r.status(), + }) + } } Err(e) => { let status = r.status(); @@ -198,23 +199,26 @@ impl RetryExt for reqwest::RequestBuilder { || now.elapsed() > retry_timeout || !status.is_server_error() { - // Get the response message if returned a client error - let message = match status.is_client_error() { + return Err(match status.is_client_error() { true => match r.text().await { - Ok(message) if !message.is_empty() => message, - Ok(_) => "No Body".to_string(), - Err(e) => format!("error getting response body: {e}") + Ok(body) => { + Error::Client { + body: Some(body).filter(|b| !b.is_empty()), + status, + } + } + Err(e) => { + Error::Reqwest { + retries, + source: e, + } + } } - false => status.to_string(), - }; - - return Err(Error{ - message, - retries, - status: Some(status), - source: Some(e), - }) - + false => Error::Reqwest { + retries, + source: e, + } + }); } let sleep = backoff.next(); @@ -238,16 +242,14 @@ impl RetryExt for reqwest::RequestBuilder { || now.elapsed() > retry_timeout || !do_retry { - return Err(Error{ + return Err(Error::Reqwest { retries, - message: "request error".to_string(), - status: e.status(), - source: Some(e), + source: e, }) } let sleep = backoff.next(); retries += 1; - info!("Encountered request error ({}) backing off for {} seconds, retry {} of {}", e, 
sleep.as_secs_f32(), retries, max_retries); + info!("Encountered transport error ({}) backing off for {} seconds, retry {} of {}", e, sleep.as_secs_f32(), retries, max_retries); tokio::time::sleep(sleep).await; } } @@ -260,7 +262,7 @@ #[cfg(test)] mod tests { use crate::client::mock_server::MockServer; - use crate::client::retry::RetryExt; + use crate::client::retry::{Error, RetryExt}; use crate::RetryConfig; use hyper::header::LOCATION; use hyper::{Body, Response}; @@ -294,8 +296,11 @@ mod tests { let e = do_request().await.unwrap_err(); assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - assert_eq!(e.retries, 0); - assert_eq!(&e.message, "cupcakes"); + assert_eq!(e.body(), Some("cupcakes")); + assert_eq!( + e.to_string(), + "Client error with status 400 Bad Request: cupcakes" + ); // Handles client errors with no payload mock.push( @@ -307,8 +312,11 @@ mod tests { let e = do_request().await.unwrap_err(); assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST); - assert_eq!(e.retries, 0); - assert_eq!(&e.message, "No Body"); + assert_eq!(e.body(), None); + assert_eq!( + e.to_string(), + "Client error with status 400 Bad Request: No Body" + ); // Should retry server error request mock.push( @@ -381,7 +389,8 @@ ); let e = do_request().await.unwrap_err(); - assert_eq!(e.message, "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); + assert!(matches!(e, Error::BareRedirect)); + assert_eq!(e.to_string(), "Received redirect without LOCATION, this normally indicates an incorrectly configured region"); // Gives up after retrying the specified number of times for _ in 0..=retry.max_retries { mock.push( Response::builder() .status(StatusCode::BAD_GATEWAY) .body(Body::from("ignored")) .unwrap(), ); } - let e = do_request().await.unwrap_err(); - assert_eq!(e.retries, retry.max_retries); - assert_eq!(e.message, "502 Bad Gateway"); + let e = do_request().await.unwrap_err().to_string(); + assert!(e.starts_with("Error after 2 retries: HTTP status server error (502 Bad Gateway) for url"), "{e}"); // Panic results in an incomplete message error in the client mock.push_fn(|_| panic!()); let r = do_request().await.unwrap(); assert_eq!(r.status(), StatusCode::OK); - // Gives up after retrying mulitiple panics + // Gives up after retrying multiple panics for _ in 0..=retry.max_retries { mock.push_fn(|_| panic!()); } - let e = do_request().await.unwrap_err(); - assert_eq!(e.retries, retry.max_retries); - assert_eq!(e.message, "request error"); + let e = do_request().await.unwrap_err().to_string(); + assert!( + e.starts_with("Error after 2 retries: error sending request for url"), + "{e}" + ); // Shutdown mock.shutdown().await diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs index 3c583c67039f..a75527fe7b9f 100644 --- a/object_store/src/gcp/mod.rs +++ b/object_store/src/gcp/mod.rs @@ -387,16 +387,11 @@ impl GetClient for GoogleCloudStorageClient { const STORE: &'static str = STORE; /// Perform a get request - async fn get_request( - &self, - path: &Path, - options: GetOptions, - head: bool, - ) -> Result { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { let credential = self.get_credential().await?; let url = self.object_url(path); - let method = match head { + let method = match options.head { true => Method::HEAD, false => Method::GET, }; @@ -602,10 +597,6 @@ impl ObjectStore for GoogleCloudStorage { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - 
self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete_request(location).await } @@ -1087,7 +1078,7 @@ impl GoogleCloudStorageBuilder { } else { Arc::new(TokenCredentialProvider::new( InstanceCredentialProvider::default(), - self.client_options.clone().with_allow_http(true).client()?, + self.client_options.metadata_client()?, self.retry_config.clone(), )) as _ }; @@ -1222,7 +1213,7 @@ mod test { .unwrap_err() .to_string(); assert!( - err.contains("HTTP status client error (404 Not Found)"), + err.contains("Client error with status 404 Not Found"), "{}", err ) diff --git a/object_store/src/http/client.rs b/object_store/src/http/client.rs index 0bd2e5639cb5..b2a6ac0aa34a 100644 --- a/object_store/src/http/client.rs +++ b/object_store/src/http/client.rs @@ -288,14 +288,9 @@ impl GetClient for Client { last_modified_required: false, }; - async fn get_request( - &self, - location: &Path, - options: GetOptions, - head: bool, - ) -> Result { - let url = self.path_url(location); - let method = match head { + async fn get_request(&self, path: &Path, options: GetOptions) -> Result { + let url = self.path_url(path); + let method = match options.head { true => Method::HEAD, false => Method::GET, }; @@ -311,7 +306,7 @@ impl GetClient for Client { Some(StatusCode::NOT_FOUND | StatusCode::METHOD_NOT_ALLOWED) => { crate::Error::NotFound { source: Box::new(source), - path: location.to_string(), + path: path.to_string(), } } _ => Error::Request { source }.into(), @@ -322,7 +317,7 @@ impl GetClient for Client { if has_range && res.status() != StatusCode::PARTIAL_CONTENT { return Err(crate::Error::NotSupported { source: Box::new(Error::RangeNotSupported { - href: location.to_string(), + href: path.to_string(), }), }); } diff --git a/object_store/src/http/mod.rs b/object_store/src/http/mod.rs index e9ed5902d8f5..6ffb62358941 100644 --- a/object_store/src/http/mod.rs +++ b/object_store/src/http/mod.rs @@ -118,10 +118,6 @@ impl ObjectStore for HttpStore { self.client.get_opts(location, options).await } - async fn head(&self, location: &Path) -> Result { - self.client.head(location).await - } - async fn delete(&self, location: &Path) -> Result<()> { self.client.delete(location).await } diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs index 3fd363fd4f06..ff0a46533dda 100644 --- a/object_store/src/lib.rs +++ b/object_store/src/lib.rs @@ -267,6 +267,8 @@ pub mod local; pub mod memory; pub mod path; pub mod prefix; +#[cfg(feature = "cloud")] +pub mod signer; pub mod throttle; #[cfg(feature = "cloud")] @@ -408,7 +410,13 @@ pub trait ObjectStore: std::fmt::Display + Send + Sync + Debug + 'static { } /// Return the metadata for the specified location - async fn head(&self, location: &Path) -> Result; + async fn head(&self, location: &Path) -> Result { + let options = GetOptions { + head: true, + ..Default::default() + }; + Ok(self.get_opts(location, options).await?.meta) + } /// Delete the object at the specified location. 
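With the new `head: true` option, the per-store `head` implementations removed above become redundant; the trait default covers them all. A minimal sketch of the same pattern from the caller's side (illustrative only; `metadata_of` is a hypothetical helper):

```rust
use object_store::{path::Path, GetOptions, ObjectMeta, ObjectStore, Result};

// Hypothetical helper: fetch only metadata. `head: true` makes the store
// issue a HEAD-style request, so no object payload is transferred.
async fn metadata_of(store: &dyn ObjectStore, location: &Path) -> Result<ObjectMeta> {
    let options = GetOptions {
        head: true,
        ..Default::default()
    };
    Ok(store.get_opts(location, options).await?.meta)
}
```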
async fn delete(&self, location: &Path) -> Result<()>; @@ -714,6 +722,10 @@ pub struct GetOptions { /// /// pub range: Option>, + /// Request transfer of no content + /// + /// + pub head: bool, } impl GetOptions { diff --git a/object_store/src/local.rs b/object_store/src/local.rs index 69da170b0872..3ed63a410815 100644 --- a/object_store/src/local.rs +++ b/object_store/src/local.rs @@ -419,35 +419,6 @@ impl ObjectStore for LocalFileSystem { .await } - async fn head(&self, location: &Path) -> Result { - let path = self.config.path_to_filesystem(location)?; - let location = location.clone(); - - maybe_spawn_blocking(move || { - let metadata = match metadata(&path) { - Err(e) => Err(match e.kind() { - ErrorKind::NotFound => Error::NotFound { - path: path.clone(), - source: e, - }, - _ => Error::Metadata { - source: e.into(), - path: location.to_string(), - }, - }), - Ok(m) => match !m.is_dir() { - true => Ok(m), - false => Err(Error::NotFound { - path, - source: io::Error::new(ErrorKind::NotFound, "is directory"), - }), - }, - }?; - convert_metadata(metadata, location) - }) - .await - } - async fn delete(&self, location: &Path) -> Result<()> { let path = self.config.path_to_filesystem(location)?; maybe_spawn_blocking(move || match std::fs::remove_file(&path) { @@ -1604,15 +1575,15 @@ mod unix_test { let path = root.path().join(filename); unistd::mkfifo(&path, stat::Mode::S_IRWXU).unwrap(); - let location = Path::from(filename); - integration.head(&location).await.unwrap(); - // Need to open read and write side in parallel let spawned = tokio::task::spawn_blocking(|| { - OpenOptions::new().write(true).open(path).unwrap(); + OpenOptions::new().write(true).open(path).unwrap() }); + let location = Path::from(filename); + integration.head(&location).await.unwrap(); integration.get(&location).await.unwrap(); + spawned.await.unwrap(); } } diff --git a/object_store/src/signer.rs b/object_store/src/signer.rs new file mode 100644 index 000000000000..f1f35debe053 --- /dev/null +++ b/object_store/src/signer.rs @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Abstraction of signed URL generation for those object store implementations that support it + +use crate::{path::Path, Result}; +use async_trait::async_trait; +use reqwest::Method; +use std::{fmt, time::Duration}; +use url::Url; + +/// Universal API to presigned URLs generated from multiple object store services. Not supported by +/// all object store services. 
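A usage sketch for the trait defined below (illustrative only; `share_read_only` is a hypothetical helper): because the trait is object safe, callers can mint time-limited URLs without knowing which signing-capable store they hold.

```rust
use object_store::{path::Path, signer::Signer, Result};
use reqwest::Method;
use std::time::Duration;
use url::Url;

// Hypothetical helper: grant read access to `path` for one hour. Works for
// any store implementing `Signer`, such as the `AmazonS3` impl added above.
async fn share_read_only(signer: &dyn Signer, path: &Path) -> Result<Url> {
    signer
        .signed_url(Method::GET, path, Duration::from_secs(3600))
        .await
}
```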
+#[async_trait] +pub trait Signer: Send + Sync + fmt::Debug + 'static { + /// Given the intended [`Method`] and [`Path`] to use and the desired length of time for which + /// the URL should be valid, return a signed [`Url`] created with the object store + /// implementation's credentials such that the URL can be handed to something that doesn't have + /// access to the object store's credentials, to allow limited access to the object store. + async fn signed_url( + &self, + method: Method, + path: &Path, + expires_in: Duration, + ) -> Result; +} diff --git a/parquet/CONTRIBUTING.md b/parquet/CONTRIBUTING.md index 903126d9f4f8..5670eef08101 100644 --- a/parquet/CONTRIBUTING.md +++ b/parquet/CONTRIBUTING.md @@ -62,10 +62,6 @@ To compile and view in the browser, run `cargo doc --no-deps --open`. ## Update Parquet Format -To generate the parquet format (thrift definitions) code run from the repository root run - -``` -$ docker run -v $(pwd):/thrift/src -it archlinux pacman -Sy --noconfirm thrift && wget https://raw.githubusercontent.com/apache/parquet-format/apache-parquet-format-2.9.0/src/main/thrift/parquet.thrift -O /tmp/parquet.thrift && thrift --gen rs /tmp/parquet.thrift && sed -i '/use thrift::server::TProcessor;/d' parquet.rs && mv parquet.rs parquet/src/format.rs -``` +To generate the parquet format (thrift definitions) code run [`./regen.sh`](./regen.sh). You may need to manually patch up doc comments that contain unescaped `[]` diff --git a/parquet/Cargo.toml b/parquet/Cargo.toml index c710c83213b9..659e2c0ee3a7 100644 --- a/parquet/Cargo.toml +++ b/parquet/Cargo.toml @@ -52,7 +52,7 @@ snap = { version = "1.0", default-features = false, optional = true } brotli = { version = "3.3", default-features = false, features = ["std"], optional = true } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"], optional = true } -zstd = { version = "0.12.0", optional = true, default-features = false } +zstd = { version = "0.13.0", optional = true, default-features = false } chrono = { workspace = true } num = { version = "0.4", default-features = false } num-bigint = { version = "0.4", default-features = false } @@ -75,7 +75,7 @@ tempfile = { version = "3.0", default-features = false } brotli = { version = "3.3", default-features = false, features = ["std"] } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"] } lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"] } -zstd = { version = "0.12", default-features = false } +zstd = { version = "0.13", default-features = false } serde_json = { version = "1.0", features = ["std"], default-features = false } arrow = { workspace = true, features = ["ipc", "test_utils", "prettyprint", "json"] } tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "io-util", "fs"] } @@ -173,5 +173,10 @@ name = "compression" required-features = ["experimental", "default"] harness = false + +[[bench]] +name = "metadata" +harness = false + [lib] bench = false diff --git a/parquet/benches/metadata.rs b/parquet/benches/metadata.rs new file mode 100644 index 000000000000..c817385f6ba9 --- /dev/null +++ b/parquet/benches/metadata.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use bytes::Bytes; +use criterion::*; +use parquet::file::reader::SerializedFileReader; +use parquet::file::serialized_reader::ReadOptionsBuilder; + +fn criterion_benchmark(c: &mut Criterion) { + // Read file into memory to isolate filesystem performance + let file = "../parquet-testing/data/alltypes_tiny_pages.parquet"; + let data = std::fs::read(file).unwrap(); + let data = Bytes::from(data); + + c.bench_function("open(default)", |b| { + b.iter(|| SerializedFileReader::new(data.clone()).unwrap()) + }); + + c.bench_function("open(page index)", |b| { + b.iter(|| { + let options = ReadOptionsBuilder::new().with_page_index().build(); + SerializedFileReader::new_with_options(data.clone(), options).unwrap() + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/parquet/regen.sh b/parquet/regen.sh new file mode 100755 index 000000000000..b8c3549e2324 --- /dev/null +++ b/parquet/regen.sh @@ -0,0 +1,35 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +REVISION=aeae80660c1d0c97314e9da837de1abdebd49c37 + +SOURCE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" && pwd)" + +docker run -v $SOURCE_DIR:/thrift/src -it archlinux pacman -Sy --noconfirm thrift && \ + wget https://raw.githubusercontent.com/apache/parquet-format/$REVISION/src/main/thrift/parquet.thrift -O /tmp/parquet.thrift && \ + thrift --gen rs /tmp/parquet.thrift && \ + echo "Removing TProcessor" && \ + sed -i '/use thrift::server::TProcessor;/d' parquet.rs && \ + echo "Replacing TSerializable" && \ + sed -i 's/impl TSerializable for/impl crate::thrift::TSerializable for/g' parquet.rs && \ + echo "Rewriting write_to_out_protocol" && \ + sed -i 's/fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol)/fn write_to_out_protocol(\&self, o_prot: \&mut T)/g' parquet.rs && \ + echo "Rewriting read_from_in_protocol" && \ + sed -i 's/fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol)/fn read_from_in_protocol(i_prot: \&mut T)/g' parquet.rs && \ + mv parquet.rs src/format.rs diff --git a/parquet/src/arrow/arrow_writer/mod.rs b/parquet/src/arrow/arrow_writer/mod.rs index 5dae81d4711c..752eff86c5e9 100644 --- a/parquet/src/arrow/arrow_writer/mod.rs +++ b/parquet/src/arrow/arrow_writer/mod.rs @@ -23,7 +23,7 @@ use std::iter::Peekable; use std::slice::Iter; use std::sync::{Arc, Mutex}; use std::vec::IntoIter; -use thrift::protocol::{TCompactOutputProtocol, TSerializable}; +use thrift::protocol::TCompactOutputProtocol; use arrow_array::cast::AsArray; use arrow_array::types::*; @@ -50,6 +50,7 @@ use crate::file::properties::{WriterProperties, WriterPropertiesPtr}; use crate::file::reader::{ChunkReader, Length}; use crate::file::writer::{SerializedFileWriter, SerializedRowGroupWriter}; use crate::schema::types::{ColumnDescPtr, SchemaDescriptor}; +use crate::thrift::TSerializable; use levels::{calculate_array_levels, ArrayLevels}; mod byte_array; diff --git a/parquet/src/arrow/async_reader/metadata.rs b/parquet/src/arrow/async_reader/metadata.rs index 076ae5c54052..fe7b4427647c 100644 --- a/parquet/src/arrow/async_reader/metadata.rs +++ b/parquet/src/arrow/async_reader/metadata.rs @@ -17,7 +17,7 @@ use crate::arrow::async_reader::AsyncFileReader; use crate::errors::{ParquetError, Result}; -use crate::file::footer::{decode_footer, read_metadata}; +use crate::file::footer::{decode_footer, decode_metadata}; use crate::file::metadata::ParquetMetaData; use crate::file::page_index::index::Index; use crate::file::page_index::index_reader::{ @@ -27,7 +27,6 @@ use bytes::Bytes; use futures::future::BoxFuture; use futures::FutureExt; use std::future::Future; -use std::io::Read; use std::ops::Range; /// A data source that can be used with [`MetadataLoader`] to load [`ParquetMetaData`] @@ -95,16 +94,14 @@ impl MetadataLoader { // Did not fetch the entire file metadata in the initial read, need to make a second request let (metadata, remainder) = if length > suffix_len - 8 { let metadata_start = file_size - length - 8; - let remaining_metadata = fetch.fetch(metadata_start..footer_start).await?; - - let reader = remaining_metadata.as_ref().chain(&suffix[..suffix_len - 8]); - (read_metadata(reader)?, None) + let meta = fetch.fetch(metadata_start..file_size - 8).await?; + (decode_metadata(&meta)?, None) } else { let metadata_start = file_size - length - 8 - footer_start; let slice = &suffix[metadata_start..suffix_len - 8]; ( - read_metadata(slice)?, + decode_metadata(slice)?, Some((footer_start, suffix.slice(..metadata_start))), ) }; diff --git a/parquet/src/arrow/async_reader/mod.rs 
b/parquet/src/arrow/async_reader/mod.rs index c749d4deeb16..4b3eebf2e67e 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -77,7 +77,6 @@ use std::collections::VecDeque; use std::fmt::Formatter; - use std::io::SeekFrom; use std::ops::Range; use std::pin::Pin; @@ -88,7 +87,6 @@ use bytes::{Buf, Bytes}; use futures::future::{BoxFuture, FutureExt}; use futures::ready; use futures::stream::Stream; - use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; use arrow_array::RecordBatch; @@ -102,15 +100,18 @@ use crate::arrow::arrow_reader::{ }; use crate::arrow::ProjectionMask; +use crate::bloom_filter::{ + chunk_read_bloom_filter_header_and_offset, Sbbf, SBBF_HEADER_SIZE_ESTIMATE, +}; use crate::column::page::{PageIterator, PageReader}; - use crate::errors::{ParquetError, Result}; use crate::file::footer::{decode_footer, decode_metadata}; use crate::file::metadata::{ParquetMetaData, RowGroupMetaData}; use crate::file::reader::{ChunkReader, Length, SerializedPageReader}; -use crate::format::PageLocation; - use crate::file::FOOTER_SIZE; +use crate::format::{ + BloomFilterAlgorithm, BloomFilterCompression, BloomFilterHash, PageLocation, +}; mod metadata; pub use metadata::*; @@ -302,6 +303,71 @@ impl ParquetRecordBatchStreamBuilder { Self::new_builder(AsyncReader(input), metadata) } + /// Read bloom filter for a column in a row group /// Returns `None` if the column does not have a bloom filter + /// + /// We should call this function after other forms of pruning, such as projection and predicate pushdown. + pub async fn get_row_group_column_bloom_filter( + &mut self, + row_group_idx: usize, + column_idx: usize, + ) -> Result> { + let metadata = self.metadata.row_group(row_group_idx); + let column_metadata = metadata.column(column_idx); + + let offset: usize = if let Some(offset) = column_metadata.bloom_filter_offset() { + offset.try_into().map_err(|_| { + ParquetError::General("Bloom filter offset is invalid".to_string()) + })? + } else { + return Ok(None); + }; + + let buffer = match column_metadata.bloom_filter_length() { + Some(length) => self.input.0.get_bytes(offset..offset + length as usize), + None => self + .input + .0 + .get_bytes(offset..offset + SBBF_HEADER_SIZE_ESTIMATE), + } + .await?; + + let (header, bitset_offset) = + chunk_read_bloom_filter_header_and_offset(offset as u64, buffer.clone())?; + + match header.algorithm { + BloomFilterAlgorithm::BLOCK(_) => { + // this match exists to future proof the singleton algorithm enum + } + } + match header.compression { + BloomFilterCompression::UNCOMPRESSED(_) => { + // this match exists to future proof the singleton compression enum + } + } + match header.hash { + BloomFilterHash::XXHASH(_) => { + // this match exists to future proof the singleton hash enum + } + } + + let bitset = match column_metadata.bloom_filter_length() { + Some(_) => buffer.slice((bitset_offset as usize - offset)..), + None => { + let bitset_length: usize = header.num_bytes.try_into().map_err(|_| { + ParquetError::General("Bloom filter length is invalid".to_string()) + })?; + self.input + .0 + .get_bytes( + bitset_offset as usize..bitset_offset as usize + bitset_length, + ) + .await? 
+ } + }; + Ok(Some(Sbbf::new(&bitset))) + } + /// Build a new [`ParquetRecordBatchStream`] pub fn build(self) -> Result> { let num_row_groups = self.metadata.row_groups().len(); @@ -1540,4 +1606,75 @@ mod tests { assert_ne!(1024, file_rows); assert_eq!(stream.batch_size, file_rows); } + + #[tokio::test] + async fn test_get_row_group_column_bloom_filter_without_length() { + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{testdata}/data_index_bloom_encoding_stats.parquet"); + let data = Bytes::from(std::fs::read(path).unwrap()); + test_get_row_group_column_bloom_filter(data, false).await; + } + + #[tokio::test] + async fn test_get_row_group_column_bloom_filter_with_length() { + // convert to new parquet file with bloom_filter_length + let testdata = arrow::util::test_util::parquet_test_data(); + let path = format!("{testdata}/data_index_bloom_encoding_stats.parquet"); + let data = Bytes::from(std::fs::read(path).unwrap()); + let metadata = parse_metadata(&data).unwrap(); + let metadata = Arc::new(metadata); + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + let builder = ParquetRecordBatchStreamBuilder::new(async_reader) + .await + .unwrap(); + let schema = builder.schema().clone(); + let stream = builder.build().unwrap(); + let batches = stream.try_collect::>().await.unwrap(); + + let mut parquet_data = Vec::new(); + let props = WriterProperties::builder() + .set_bloom_filter_enabled(true) + .build(); + let mut writer = + ArrowWriter::try_new(&mut parquet_data, schema, Some(props)).unwrap(); + for batch in batches { + writer.write(&batch).unwrap(); + } + writer.close().unwrap(); + + // test the new parquet file + test_get_row_group_column_bloom_filter(parquet_data.into(), true).await; + } + + async fn test_get_row_group_column_bloom_filter(data: Bytes, with_length: bool) { + let metadata = parse_metadata(&data).unwrap(); + let metadata = Arc::new(metadata); + + assert_eq!(metadata.num_row_groups(), 1); + let row_group = metadata.row_group(0); + let column = row_group.column(0); + assert_eq!(column.bloom_filter_length().is_some(), with_length); + + let async_reader = TestReader { + data: data.clone(), + metadata: metadata.clone(), + requests: Default::default(), + }; + + let mut builder = ParquetRecordBatchStreamBuilder::new(async_reader) + .await + .unwrap(); + + let sbbf = builder + .get_row_group_column_bloom_filter(0, 0) + .await + .unwrap() + .unwrap(); + assert!(sbbf.check(&"Hello")); + assert!(!sbbf.check(&"Hello_Not_Exists")); + } } diff --git a/parquet/src/bin/parquet-layout.rs b/parquet/src/bin/parquet-layout.rs index d749bb8a4ba7..901ac9ea2309 100644 --- a/parquet/src/bin/parquet-layout.rs +++ b/parquet/src/bin/parquet-layout.rs @@ -38,12 +38,13 @@ use std::io::Read; use clap::Parser; use serde::Serialize; -use thrift::protocol::{TCompactInputProtocol, TSerializable}; +use thrift::protocol::TCompactInputProtocol; use parquet::basic::{Compression, Encoding}; use parquet::errors::Result; use parquet::file::reader::ChunkReader; use parquet::format::PageHeader; +use parquet::thrift::TSerializable; #[derive(Serialize, Debug)] struct ParquetFile { diff --git a/parquet/src/bloom_filter/mod.rs b/parquet/src/bloom_filter/mod.rs index c893d492b52a..e98aee9fd213 100644 --- a/parquet/src/bloom_filter/mod.rs +++ b/parquet/src/bloom_filter/mod.rs @@ -26,13 +26,12 @@ use crate::format::{ BloomFilterAlgorithm, BloomFilterCompression, BloomFilterHash, BloomFilterHeader, SplitBlockAlgorithm, 
Uncompressed, XxHash, }; -use bytes::{Buf, Bytes}; +use crate::thrift::{TCompactSliceInputProtocol, TSerializable}; +use bytes::Bytes; use std::hash::Hasher; use std::io::Write; use std::sync::Arc; -use thrift::protocol::{ - TCompactInputProtocol, TCompactOutputProtocol, TOutputProtocol, TSerializable, -}; +use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol}; use twox_hash::XxHash64; /// Salt as defined in the [spec](https://github.com/apache/parquet-format/blob/master/BloomFilter.md#technical-approach). @@ -133,11 +132,11 @@ impl std::ops::IndexMut for Block { #[derive(Debug, Clone)] pub struct Sbbf(Vec); -const SBBF_HEADER_SIZE_ESTIMATE: usize = 20; +pub(crate) const SBBF_HEADER_SIZE_ESTIMATE: usize = 20; /// given an initial offset, and a byte buffer, try to read out a bloom filter header and return /// both the header and the offset after it (for bitset). -fn chunk_read_bloom_filter_header_and_offset( +pub(crate) fn chunk_read_bloom_filter_header_and_offset( offset: u64, buffer: Bytes, ) -> Result<(BloomFilterHeader, u64), ParquetError> { @@ -148,19 +147,15 @@ fn chunk_read_bloom_filter_header_and_offset( /// given a [Bytes] buffer, try to read out a bloom filter header and return both the header and /// length of the header. #[inline] -fn read_bloom_filter_header_and_length( +pub(crate) fn read_bloom_filter_header_and_length( buffer: Bytes, ) -> Result<(BloomFilterHeader, u64), ParquetError> { let total_length = buffer.len(); - let mut buf_reader = buffer.reader(); - let mut prot = TCompactInputProtocol::new(&mut buf_reader); + let mut prot = TCompactSliceInputProtocol::new(buffer.as_ref()); let header = BloomFilterHeader::read_from_in_protocol(&mut prot).map_err(|e| { ParquetError::General(format!("Could not read bloom filter header: {e}")) })?; - Ok(( - header, - (total_length - buf_reader.into_inner().remaining()) as u64, - )) + Ok((header, (total_length - prot.as_slice().len()) as u64)) } pub(crate) const BITSET_MIN_LENGTH: usize = 32; @@ -204,7 +199,7 @@ impl Sbbf { Self::new(&bitset) } - fn new(bitset: &[u8]) -> Self { + pub(crate) fn new(bitset: &[u8]) -> Self { let data = bitset .chunks_exact(4 * 8) .map(|chunk| { diff --git a/parquet/src/file/footer.rs b/parquet/src/file/footer.rs index 21de63e0c234..53496a66b572 100644 --- a/parquet/src/file/footer.rs +++ b/parquet/src/file/footer.rs @@ -18,7 +18,7 @@ use std::{io::Read, sync::Arc}; use crate::format::{ColumnOrder as TColumnOrder, FileMetaData as TFileMetaData}; -use thrift::protocol::{TCompactInputProtocol, TSerializable}; +use crate::thrift::{TCompactSliceInputProtocol, TSerializable}; use crate::basic::ColumnOrder; @@ -62,18 +62,13 @@ pub fn parse_metadata(chunk_reader: &R) -> Result Result { - read_metadata(metadata_read) -} - -/// Decodes [`ParquetMetaData`] from the provided [`Read`] -pub(crate) fn read_metadata(read: R) -> Result { +pub fn decode_metadata(buf: &[u8]) -> Result { // TODO: row group filtering - let mut prot = TCompactInputProtocol::new(read); + let mut prot = TCompactSliceInputProtocol::new(buf); let t_file_metadata: TFileMetaData = TFileMetaData::read_from_in_protocol(&mut prot) .map_err(|e| ParquetError::General(format!("Could not parse metadata: {e}")))?; let schema = types::from_thrift(&t_file_metadata.schema)?; diff --git a/parquet/src/file/page_index/index_reader.rs b/parquet/src/file/page_index/index_reader.rs index c36708a59aeb..ae3bf3699c1c 100644 --- a/parquet/src/file/page_index/index_reader.rs +++ b/parquet/src/file/page_index/index_reader.rs @@ -24,9 +24,8 @@ use 
crate::file::metadata::ColumnChunkMetaData; use crate::file::page_index::index::{Index, NativeIndex}; use crate::file::reader::ChunkReader; use crate::format::{ColumnIndex, OffsetIndex, PageLocation}; -use std::io::Cursor; +use crate::thrift::{TCompactSliceInputProtocol, TSerializable}; use std::ops::Range; -use thrift::protocol::{TCompactInputProtocol, TSerializable}; /// Computes the covering range of two optional ranges /// @@ -116,7 +115,7 @@ pub fn read_pages_locations( pub(crate) fn decode_offset_index( data: &[u8], ) -> Result, ParquetError> { - let mut prot = TCompactInputProtocol::new(data); + let mut prot = TCompactSliceInputProtocol::new(data); let offset = OffsetIndex::read_from_in_protocol(&mut prot)?; Ok(offset.page_locations) } @@ -125,8 +124,7 @@ pub(crate) fn decode_column_index( data: &[u8], column_type: Type, ) -> Result { - let mut d = Cursor::new(data); - let mut prot = TCompactInputProtocol::new(&mut d); + let mut prot = TCompactSliceInputProtocol::new(data); let index = ColumnIndex::read_from_in_protocol(&mut prot)?; diff --git a/parquet/src/file/serialized_reader.rs b/parquet/src/file/serialized_reader.rs index 4924dcc6f35a..4bc484144a81 100644 --- a/parquet/src/file/serialized_reader.rs +++ b/parquet/src/file/serialized_reader.rs @@ -19,7 +19,6 @@ //! Also contains implementations of the ChunkReader for files (with buffering) and byte arrays (RAM) use std::collections::VecDeque; -use std::io::Cursor; use std::iter; use std::{convert::TryFrom, fs::File, io::Read, path::Path, sync::Arc}; @@ -40,8 +39,9 @@ use crate::format::{PageHeader, PageLocation, PageType}; use crate::record::reader::RowIter; use crate::record::Row; use crate::schema::types::Type as SchemaType; +use crate::thrift::{TCompactSliceInputProtocol, TSerializable}; use crate::util::memory::ByteBufferPtr; -use thrift::protocol::{TCompactInputProtocol, TSerializable}; +use thrift::protocol::TCompactInputProtocol; impl TryFrom for SerializedFileReader { type Error = ParquetError; @@ -661,11 +661,11 @@ impl PageReader for SerializedPageReader { let buffer = self.reader.get_bytes(front.offset as u64, page_len)?; - let mut cursor = Cursor::new(buffer.as_ref()); - let header = read_page_header(&mut cursor)?; - let offset = cursor.position(); + let mut prot = TCompactSliceInputProtocol::new(buffer.as_ref()); + let header = PageHeader::read_from_in_protocol(&mut prot)?; + let offset = buffer.len() - prot.as_slice().len(); - let bytes = buffer.slice(offset as usize..); + let bytes = buffer.slice(offset..); decode_page( header, bytes.into(), diff --git a/parquet/src/file/writer.rs b/parquet/src/file/writer.rs index d723158de9f4..7796be6013df 100644 --- a/parquet/src/file/writer.rs +++ b/parquet/src/file/writer.rs @@ -21,10 +21,11 @@ use crate::bloom_filter::Sbbf; use crate::format as parquet; use crate::format::{ColumnIndex, OffsetIndex, RowGroup}; +use crate::thrift::TSerializable; use std::fmt::Debug; use std::io::{BufWriter, IoSlice, Read}; use std::{io::Write, sync::Arc}; -use thrift::protocol::{TCompactOutputProtocol, TSerializable}; +use thrift::protocol::TCompactOutputProtocol; use crate::column::writer::{ get_typed_column_writer_mut, ColumnCloseResult, ColumnWriterImpl, diff --git a/parquet/src/format.rs b/parquet/src/format.rs index 12c572c23cf5..46adc39e6406 100644 --- a/parquet/src/format.rs +++ b/parquet/src/format.rs @@ -53,12 +53,12 @@ impl Type { ]; } -impl TSerializable for Type { +impl crate::thrift::TSerializable for Type { #[allow(clippy::trivially_copy_pass_by_ref)] - fn 
write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     o_prot.write_i32(self.0)
   }
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<Type> {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<Type> {
     let enum_value = i_prot.read_i32()?;
     Ok(Type::from(enum_value))
   }
@@ -222,12 +222,12 @@ impl ConvertedType {
   ];
 }
 
-impl TSerializable for ConvertedType {
+impl crate::thrift::TSerializable for ConvertedType {
   #[allow(clippy::trivially_copy_pass_by_ref)]
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     o_prot.write_i32(self.0)
   }
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<ConvertedType> {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<ConvertedType> {
     let enum_value = i_prot.read_i32()?;
     Ok(ConvertedType::from(enum_value))
   }
@@ -299,12 +299,12 @@ impl FieldRepetitionType {
   ];
 }
 
-impl TSerializable for FieldRepetitionType {
+impl crate::thrift::TSerializable for FieldRepetitionType {
   #[allow(clippy::trivially_copy_pass_by_ref)]
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     o_prot.write_i32(self.0)
   }
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<FieldRepetitionType> {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<FieldRepetitionType> {
     let enum_value = i_prot.read_i32()?;
     Ok(FieldRepetitionType::from(enum_value))
   }
@@ -397,12 +397,12 @@ impl Encoding {
   ];
 }
 
-impl TSerializable for Encoding {
+impl crate::thrift::TSerializable for Encoding {
   #[allow(clippy::trivially_copy_pass_by_ref)]
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     o_prot.write_i32(self.0)
   }
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<Encoding> {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<Encoding> {
     let enum_value = i_prot.read_i32()?;
     Ok(Encoding::from(enum_value))
   }
@@ -474,12 +474,12 @@ impl CompressionCodec {
   ];
 }
 
-impl TSerializable for CompressionCodec {
+impl crate::thrift::TSerializable for CompressionCodec {
   #[allow(clippy::trivially_copy_pass_by_ref)]
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     o_prot.write_i32(self.0)
   }
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<CompressionCodec> {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<CompressionCodec> {
     let enum_value = i_prot.read_i32()?;
     Ok(CompressionCodec::from(enum_value))
   }
@@ -535,12 +535,12 @@ impl PageType {
   ];
 }
 
-impl TSerializable for PageType {
+impl crate::thrift::TSerializable for PageType {
   #[allow(clippy::trivially_copy_pass_by_ref)]
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     o_prot.write_i32(self.0)
   }
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<PageType> {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<PageType> {
     let enum_value = i_prot.read_i32()?;
     Ok(PageType::from(enum_value))
   }
@@ -592,12 +592,12 @@ impl BoundaryOrder {
   ];
 }
 
-impl TSerializable for BoundaryOrder {
+impl crate::thrift::TSerializable for BoundaryOrder {
   #[allow(clippy::trivially_copy_pass_by_ref)]
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     o_prot.write_i32(self.0)
   }
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<BoundaryOrder> {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<BoundaryOrder> {
     let enum_value = i_prot.read_i32()?;
     Ok(BoundaryOrder::from(enum_value))
   }
@@ -678,8 +678,8 @@ impl Statistics {
   }
 }
 
-impl TSerializable for Statistics {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<Statistics> {
+impl crate::thrift::TSerializable for Statistics {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<Statistics> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<Vec<u8>> = None;
     let mut f_2: Option<Vec<u8>> = None;
@@ -735,7 +735,7 @@ impl TSerializable for Statistics {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("Statistics");
     o_prot.write_struct_begin(&struct_ident)?;
     if let Some(ref fld_var) = self.max {
@@ -788,8 +788,8 @@ impl StringType {
   }
 }
 
-impl TSerializable for StringType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<StringType> {
+impl crate::thrift::TSerializable for StringType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<StringType> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -803,7 +803,7 @@ impl TSerializable for StringType {
     let ret = StringType {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("StringType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -825,8 +825,8 @@ impl UUIDType {
   }
 }
 
-impl TSerializable for UUIDType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<UUIDType> {
+impl crate::thrift::TSerializable for UUIDType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<UUIDType> {
     i_prot.read_struct_begin()?;
     loop {
      let field_ident = i_prot.read_field_begin()?;
@@ -840,7 +840,7 @@ impl TSerializable for UUIDType {
     let ret = UUIDType {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("UUIDType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -862,8 +862,8 @@ impl MapType {
   }
 }
 
-impl TSerializable for MapType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<MapType> {
+impl crate::thrift::TSerializable for MapType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<MapType> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -877,7 +877,7 @@ impl TSerializable for MapType {
     let ret = MapType {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("MapType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -899,8 +899,8 @@ impl ListType {
   }
 }
 
-impl TSerializable for ListType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<ListType> {
+impl crate::thrift::TSerializable for ListType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<ListType> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -914,7 +914,7 @@ impl TSerializable for ListType {
     let ret = ListType {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("ListType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -936,8 +936,8 @@ impl EnumType {
   }
 }
 
-impl TSerializable for EnumType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<EnumType> {
+impl crate::thrift::TSerializable for EnumType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<EnumType> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -951,7 +951,7 @@ impl TSerializable for EnumType {
     let ret = EnumType {};
     Ok(ret)
  }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("EnumType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -973,8 +973,8 @@ impl DateType {
   }
 }
 
-impl TSerializable for DateType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<DateType> {
+impl crate::thrift::TSerializable for DateType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<DateType> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -988,7 +988,7 @@ impl TSerializable for DateType {
     let ret = DateType {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("DateType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -1015,8 +1015,8 @@ impl NullType {
   }
 }
 
-impl TSerializable for NullType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<NullType> {
+impl crate::thrift::TSerializable for NullType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<NullType> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -1030,7 +1030,7 @@ impl TSerializable for NullType {
     let ret = NullType {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("NullType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -1066,8 +1066,8 @@ impl DecimalType {
   }
 }
 
-impl TSerializable for DecimalType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<DecimalType> {
+impl crate::thrift::TSerializable for DecimalType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<DecimalType> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<i32> = None;
     let mut f_2: Option<i32> = None;
@@ -1101,7 +1101,7 @@ impl TSerializable for DecimalType {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("DecimalType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("scale", TType::I32, 1))?;
@@ -1130,8 +1130,8 @@ impl MilliSeconds {
   }
 }
 
-impl TSerializable for MilliSeconds {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<MilliSeconds> {
+impl crate::thrift::TSerializable for MilliSeconds {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<MilliSeconds> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -1145,7 +1145,7 @@ impl TSerializable for MilliSeconds {
     let ret = MilliSeconds {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("MilliSeconds");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -1167,8 +1167,8 @@ impl MicroSeconds {
   }
 }
 
-impl TSerializable for MicroSeconds {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<MicroSeconds> {
+impl crate::thrift::TSerializable for MicroSeconds {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<MicroSeconds> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -1182,7 +1182,7 @@ impl TSerializable for MicroSeconds {
     let ret = MicroSeconds {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("MicroSeconds");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -1204,8 +1204,8 @@ impl NanoSeconds {
   }
 }
 
-impl TSerializable for NanoSeconds {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<NanoSeconds> {
+impl crate::thrift::TSerializable for NanoSeconds {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<NanoSeconds> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -1219,7 +1219,7 @@ impl TSerializable for NanoSeconds {
     let ret = NanoSeconds {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("NanoSeconds");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -1238,8 +1238,8 @@ pub enum TimeUnit {
   NANOS(NanoSeconds),
 }
 
-impl TSerializable for TimeUnit {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<TimeUnit> {
+impl crate::thrift::TSerializable for TimeUnit {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<TimeUnit> {
     let mut ret: Option<TimeUnit> = None;
     let mut received_field_count = 0;
     i_prot.read_struct_begin()?;
@@ -1301,7 +1301,7 @@ impl TSerializable for TimeUnit {
       Ok(ret.expect("return value should have been constructed"))
     }
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("TimeUnit");
     o_prot.write_struct_begin(&struct_ident)?;
     match *self {
@@ -1348,8 +1348,8 @@ impl TimestampType {
   }
 }
 
-impl TSerializable for TimestampType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<TimestampType> {
+impl crate::thrift::TSerializable for TimestampType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<TimestampType> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<bool> = None;
     let mut f_2: Option<TimeUnit> = None;
@@ -1383,7 +1383,7 @@ impl TSerializable for TimestampType {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("TimestampType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("isAdjustedToUTC", TType::Bool, 1))?;
@@ -1419,8 +1419,8 @@ impl TimeType {
   }
 }
 
-impl TSerializable for TimeType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<TimeType> {
+impl crate::thrift::TSerializable for TimeType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<TimeType> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<bool> = None;
     let mut f_2: Option<TimeUnit> = None;
@@ -1454,7 +1454,7 @@ impl TSerializable for TimeType {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("TimeType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("isAdjustedToUTC", TType::Bool, 1))?;
@@ -1492,8 +1492,8 @@ impl IntType {
   }
 }
 
-impl TSerializable for IntType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<IntType> {
+impl crate::thrift::TSerializable for IntType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<IntType> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<i8> = None;
     let mut f_2: Option<bool> = None;
@@ -1527,7 +1527,7 @@ impl TSerializable for IntType {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("IntType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("bitWidth", TType::I08, 1))?;
@@ -1558,8 +1558,8 @@ impl JsonType {
   }
 }
 
-impl TSerializable for JsonType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<JsonType> {
+impl crate::thrift::TSerializable for JsonType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<JsonType> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -1573,7 +1573,7 @@ impl TSerializable for JsonType {
     let ret = JsonType {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("JsonType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -1598,8 +1598,8 @@ impl BsonType {
   }
 }
 
-impl TSerializable for BsonType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<BsonType> {
+impl crate::thrift::TSerializable for BsonType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<BsonType> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -1613,7 +1613,7 @@ impl TSerializable for BsonType {
     let ret = BsonType {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("BsonType");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -1642,8 +1642,8 @@ pub enum LogicalType {
   UUID(UUIDType),
 }
 
-impl TSerializable for LogicalType {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<LogicalType> {
+impl crate::thrift::TSerializable for LogicalType {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<LogicalType> {
     let mut ret: Option<LogicalType> = None;
     let mut received_field_count = 0;
     i_prot.read_struct_begin()?;
@@ -1775,7 +1775,7 @@ impl TSerializable for LogicalType {
      Ok(ret.expect("return value should have been constructed"))
    }
  }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("LogicalType");
     o_prot.write_struct_begin(&struct_ident)?;
     match *self {
@@ -1915,8 +1915,8 @@ impl SchemaElement {
   }
 }
 
-impl TSerializable for SchemaElement {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<SchemaElement> {
+impl crate::thrift::TSerializable for SchemaElement {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<SchemaElement> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<Type> = None;
     let mut f_2: Option<i32> = None;
@@ -1997,7 +1997,7 @@ impl TSerializable for SchemaElement {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("SchemaElement");
     o_prot.write_struct_begin(&struct_ident)?;
     if let Some(ref fld_var) = self.type_ {
@@ -2084,8 +2084,8 @@ impl DataPageHeader {
   }
 }
 
-impl TSerializable for DataPageHeader {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<DataPageHeader> {
+impl crate::thrift::TSerializable for DataPageHeader {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<DataPageHeader> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<i32> = None;
     let mut f_2: Option<Encoding> = None;
@@ -2139,7 +2139,7 @@ impl TSerializable for DataPageHeader {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("DataPageHeader");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("num_values", TType::I32, 1))?;
@@ -2178,8 +2178,8 @@ impl IndexPageHeader {
   }
 }
 
-impl TSerializable for IndexPageHeader {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<IndexPageHeader> {
+impl crate::thrift::TSerializable for IndexPageHeader {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<IndexPageHeader> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -2193,7 +2193,7 @@ impl TSerializable for IndexPageHeader {
     let ret = IndexPageHeader {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("IndexPageHeader");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -2229,8 +2229,8 @@ impl DictionaryPageHeader {
   }
 }
 
-impl TSerializable for DictionaryPageHeader {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<DictionaryPageHeader> {
+impl crate::thrift::TSerializable for DictionaryPageHeader {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<DictionaryPageHeader> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<i32> = None;
     let mut f_2: Option<Encoding> = None;
@@ -2270,7 +2270,7 @@ impl TSerializable for DictionaryPageHeader {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("DictionaryPageHeader");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("num_values", TType::I32, 1))?;
@@ -2337,8 +2337,8 @@ impl DataPageHeaderV2 {
   }
 }
 
-impl TSerializable for DataPageHeaderV2 {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<DataPageHeaderV2> {
+impl crate::thrift::TSerializable for DataPageHeaderV2 {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<DataPageHeaderV2> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<i32> = None;
     let mut f_2: Option<i32> = None;
@@ -2412,7 +2412,7 @@ impl TSerializable for DataPageHeaderV2 {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("DataPageHeaderV2");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("num_values", TType::I32, 1))?;
@@ -2463,8 +2463,8 @@ impl SplitBlockAlgorithm {
   }
 }
 
-impl TSerializable for SplitBlockAlgorithm {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<SplitBlockAlgorithm> {
+impl crate::thrift::TSerializable for SplitBlockAlgorithm {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<SplitBlockAlgorithm> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -2478,7 +2478,7 @@ impl TSerializable for SplitBlockAlgorithm {
     let ret = SplitBlockAlgorithm {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("SplitBlockAlgorithm");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -2495,8 +2495,8 @@ pub enum BloomFilterAlgorithm {
   BLOCK(SplitBlockAlgorithm),
 }
 
-impl TSerializable for BloomFilterAlgorithm {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<BloomFilterAlgorithm> {
+impl crate::thrift::TSerializable for BloomFilterAlgorithm {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<BloomFilterAlgorithm> {
     let mut ret: Option<BloomFilterAlgorithm> = None;
     let mut received_field_count = 0;
     i_prot.read_struct_begin()?;
@@ -2544,7 +2544,7 @@ impl TSerializable for BloomFilterAlgorithm {
      Ok(ret.expect("return value should have been constructed"))
    }
  }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("BloomFilterAlgorithm");
     o_prot.write_struct_begin(&struct_ident)?;
     match *self {
@@ -2576,8 +2576,8 @@ impl XxHash {
   }
 }
 
-impl TSerializable for XxHash {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<XxHash> {
+impl crate::thrift::TSerializable for XxHash {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<XxHash> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -2591,7 +2591,7 @@ impl TSerializable for XxHash {
     let ret = XxHash {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("XxHash");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -2608,8 +2608,8 @@ pub enum BloomFilterHash {
   XXHASH(XxHash),
 }
 
-impl TSerializable for BloomFilterHash {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<BloomFilterHash> {
+impl crate::thrift::TSerializable for BloomFilterHash {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<BloomFilterHash> {
     let mut ret: Option<BloomFilterHash> = None;
     let mut received_field_count = 0;
     i_prot.read_struct_begin()?;
@@ -2657,7 +2657,7 @@ impl TSerializable for BloomFilterHash {
      Ok(ret.expect("return value should have been constructed"))
    }
  }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("BloomFilterHash");
     o_prot.write_struct_begin(&struct_ident)?;
     match *self {
@@ -2688,8 +2688,8 @@ impl Uncompressed {
   }
 }
 
-impl TSerializable for Uncompressed {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<Uncompressed> {
+impl crate::thrift::TSerializable for Uncompressed {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<Uncompressed> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -2703,7 +2703,7 @@ impl TSerializable for Uncompressed {
     let ret = Uncompressed {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("Uncompressed");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -2720,8 +2720,8 @@ pub enum BloomFilterCompression {
   UNCOMPRESSED(Uncompressed),
 }
 
-impl TSerializable for BloomFilterCompression {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<BloomFilterCompression> {
+impl crate::thrift::TSerializable for BloomFilterCompression {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<BloomFilterCompression> {
     let mut ret: Option<BloomFilterCompression> = None;
     let mut received_field_count = 0;
     i_prot.read_struct_begin()?;
@@ -2769,7 +2769,7 @@ impl TSerializable for BloomFilterCompression {
      Ok(ret.expect("return value should have been constructed"))
    }
  }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("BloomFilterCompression");
     o_prot.write_struct_begin(&struct_ident)?;
     match *self {
@@ -2814,8 +2814,8 @@ impl BloomFilterHeader {
   }
 }
 
-impl TSerializable for BloomFilterHeader {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<BloomFilterHeader> {
+impl crate::thrift::TSerializable for BloomFilterHeader {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<BloomFilterHeader> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<i32> = None;
     let mut f_2: Option<BloomFilterAlgorithm> = None;
@@ -2863,7 +2863,7 @@ impl TSerializable for BloomFilterHeader {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("BloomFilterHeader");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("numBytes", TType::I32, 1))?;
@@ -2933,8 +2933,8 @@ impl PageHeader {
   }
 }
 
-impl TSerializable for PageHeader {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<PageHeader> {
+impl crate::thrift::TSerializable for PageHeader {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<PageHeader> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<PageType> = None;
     let mut f_2: Option<i32> = None;
@@ -3005,7 +3005,7 @@ impl TSerializable for PageHeader {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("PageHeader");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("type", TType::I32, 1))?;
@@ -3067,8 +3067,8 @@ impl KeyValue {
   }
 }
 
-impl TSerializable for KeyValue {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<KeyValue> {
+impl crate::thrift::TSerializable for KeyValue {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<KeyValue> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<String> = None;
     let mut f_2: Option<String> = None;
@@ -3101,7 +3101,7 @@ impl TSerializable for KeyValue {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("KeyValue");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("key", TType::String, 1))?;
@@ -3143,8 +3143,8 @@ impl SortingColumn {
   }
 }
 
-impl TSerializable for SortingColumn {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<SortingColumn> {
+impl crate::thrift::TSerializable for SortingColumn {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<SortingColumn> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<i32> = None;
     let mut f_2: Option<bool> = None;
@@ -3185,7 +3185,7 @@ impl TSerializable for SortingColumn {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("SortingColumn");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("column_idx", TType::I32, 1))?;
@@ -3227,8 +3227,8 @@ impl PageEncodingStats {
   }
 }
 
-impl TSerializable for PageEncodingStats {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<PageEncodingStats> {
+impl crate::thrift::TSerializable for PageEncodingStats {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<PageEncodingStats> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<PageType> = None;
     let mut f_2: Option<Encoding> = None;
@@ -3269,7 +3269,7 @@ impl TSerializable for PageEncodingStats {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("PageEncodingStats");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("page_type", TType::I32, 1))?;
@@ -3355,8 +3355,8 @@ impl ColumnMetaData {
   }
 }
 
-impl TSerializable for ColumnMetaData {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<ColumnMetaData> {
+impl crate::thrift::TSerializable for ColumnMetaData {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<ColumnMetaData> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<Type> = None;
     let mut f_2: Option<Vec<Encoding>> = None;
@@ -3498,7 +3498,7 @@ impl TSerializable for ColumnMetaData {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("ColumnMetaData");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("type", TType::I32, 1))?;
@@ -3595,8 +3595,8 @@ impl EncryptionWithFooterKey {
   }
 }
 
-impl TSerializable for EncryptionWithFooterKey {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<EncryptionWithFooterKey> {
+impl crate::thrift::TSerializable for EncryptionWithFooterKey {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<EncryptionWithFooterKey> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -3610,7 +3610,7 @@ impl TSerializable for EncryptionWithFooterKey {
     let ret = EncryptionWithFooterKey {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("EncryptionWithFooterKey");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -3639,8 +3639,8 @@ impl EncryptionWithColumnKey {
   }
 }
 
-impl TSerializable for EncryptionWithColumnKey {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<EncryptionWithColumnKey> {
+impl crate::thrift::TSerializable for EncryptionWithColumnKey {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<EncryptionWithColumnKey> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<Vec<String>> = None;
     let mut f_2: Option<Vec<u8>> = None;
@@ -3679,7 +3679,7 @@ impl TSerializable for EncryptionWithColumnKey {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("EncryptionWithColumnKey");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("path_in_schema", TType::List, 1))?;
@@ -3709,8 +3709,8 @@ pub enum ColumnCryptoMetaData {
   ENCRYPTIONWITHCOLUMNKEY(EncryptionWithColumnKey),
 }
 
-impl TSerializable for ColumnCryptoMetaData {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<ColumnCryptoMetaData> {
+impl crate::thrift::TSerializable for ColumnCryptoMetaData {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<ColumnCryptoMetaData> {
     let mut ret: Option<ColumnCryptoMetaData> = None;
     let mut received_field_count = 0;
     i_prot.read_struct_begin()?;
@@ -3765,7 +3765,7 @@ impl TSerializable for ColumnCryptoMetaData {
      Ok(ret.expect("return value should have been constructed"))
    }
  }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("ColumnCryptoMetaData");
     o_prot.write_struct_begin(&struct_ident)?;
     match *self {
@@ -3832,8 +3832,8 @@ impl ColumnChunk {
   }
 }
 
-impl TSerializable for ColumnChunk {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<ColumnChunk> {
+impl crate::thrift::TSerializable for ColumnChunk {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<ColumnChunk> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<String> = None;
     let mut f_2: Option<i64> = None;
@@ -3908,7 +3908,7 @@ impl TSerializable for ColumnChunk {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("ColumnChunk");
     o_prot.write_struct_begin(&struct_ident)?;
     if let Some(ref fld_var) = self.file_path {
@@ -4000,8 +4000,8 @@ impl RowGroup {
   }
 }
 
-impl TSerializable for RowGroup {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<RowGroup> {
+impl crate::thrift::TSerializable for RowGroup {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<RowGroup> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<Vec<ColumnChunk>> = None;
     let mut f_2: Option<i64> = None;
@@ -4078,7 +4078,7 @@ impl TSerializable for RowGroup {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("RowGroup");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("columns", TType::List, 1))?;
@@ -4138,8 +4138,8 @@ impl TypeDefinedOrder {
   }
 }
 
-impl TSerializable for TypeDefinedOrder {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<TypeDefinedOrder> {
+impl crate::thrift::TSerializable for TypeDefinedOrder {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<TypeDefinedOrder> {
     i_prot.read_struct_begin()?;
     loop {
       let field_ident = i_prot.read_field_begin()?;
@@ -4153,7 +4153,7 @@ impl TSerializable for TypeDefinedOrder {
     let ret = TypeDefinedOrder {};
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("TypeDefinedOrder");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_stop()?;
@@ -4170,8 +4170,8 @@ pub enum ColumnOrder {
   TYPEORDER(TypeDefinedOrder),
 }
 
-impl TSerializable for ColumnOrder {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<ColumnOrder> {
+impl crate::thrift::TSerializable for ColumnOrder {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<ColumnOrder> {
     let mut ret: Option<ColumnOrder> = None;
     let mut received_field_count = 0;
     i_prot.read_struct_begin()?;
@@ -4219,7 +4219,7 @@ impl TSerializable for ColumnOrder {
      Ok(ret.expect("return value should have been constructed"))
    }
  }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("ColumnOrder");
     o_prot.write_struct_begin(&struct_ident)?;
     match *self {
@@ -4260,8 +4260,8 @@ impl PageLocation {
   }
 }
 
-impl TSerializable for PageLocation {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<PageLocation> {
+impl crate::thrift::TSerializable for PageLocation {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<PageLocation> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<i64> = None;
     let mut f_2: Option<i32> = None;
@@ -4302,7 +4302,7 @@ impl TSerializable for PageLocation {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("PageLocation");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("offset", TType::I64, 1))?;
@@ -4338,8 +4338,8 @@ impl OffsetIndex {
   }
 }
 
-impl TSerializable for OffsetIndex {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<OffsetIndex> {
+impl crate::thrift::TSerializable for OffsetIndex {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<OffsetIndex> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<Vec<PageLocation>> = None;
     loop {
@@ -4372,7 +4372,7 @@ impl TSerializable for OffsetIndex {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("OffsetIndex");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("page_locations", TType::List, 1))?;
@@ -4432,8 +4432,8 @@ impl ColumnIndex {
   }
 }
 
-impl TSerializable for ColumnIndex {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<ColumnIndex> {
+impl crate::thrift::TSerializable for ColumnIndex {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<ColumnIndex> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<Vec<bool>> = None;
     let mut f_2: Option<Vec<Vec<u8>>> = None;
@@ -4511,7 +4511,7 @@ impl TSerializable for ColumnIndex {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("ColumnIndex");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("null_pages", TType::List, 1))?;
@@ -4577,8 +4577,8 @@ impl AesGcmV1 {
   }
 }
 
-impl TSerializable for AesGcmV1 {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<AesGcmV1> {
+impl crate::thrift::TSerializable for AesGcmV1 {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<AesGcmV1> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<Vec<u8>> = None;
     let mut f_2: Option<Vec<u8>> = None;
@@ -4616,7 +4616,7 @@ impl TSerializable for AesGcmV1 {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("AesGcmV1");
     o_prot.write_struct_begin(&struct_ident)?;
     if let Some(ref fld_var) = self.aad_prefix {
@@ -4664,8 +4664,8 @@ impl AesGcmCtrV1 {
   }
 }
 
-impl TSerializable for AesGcmCtrV1 {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<AesGcmCtrV1> {
+impl crate::thrift::TSerializable for AesGcmCtrV1 {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<AesGcmCtrV1> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<Vec<u8>> = None;
     let mut f_2: Option<Vec<u8>> = None;
@@ -4703,7 +4703,7 @@ impl TSerializable for AesGcmCtrV1 {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("AesGcmCtrV1");
     o_prot.write_struct_begin(&struct_ident)?;
     if let Some(ref fld_var) = self.aad_prefix {
@@ -4736,8 +4736,8 @@ pub enum EncryptionAlgorithm {
   AESGCMCTRV1(AesGcmCtrV1),
 }
 
-impl TSerializable for EncryptionAlgorithm {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<EncryptionAlgorithm> {
+impl crate::thrift::TSerializable for EncryptionAlgorithm {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<EncryptionAlgorithm> {
     let mut ret: Option<EncryptionAlgorithm> = None;
     let mut received_field_count = 0;
     i_prot.read_struct_begin()?;
@@ -4792,7 +4792,7 @@ impl TSerializable for EncryptionAlgorithm {
      Ok(ret.expect("return value should have been constructed"))
    }
  }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("EncryptionAlgorithm");
     o_prot.write_struct_begin(&struct_ident)?;
     match *self {
@@ -4879,8 +4879,8 @@ impl FileMetaData {
   }
 }
 
-impl TSerializable for FileMetaData {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<FileMetaData> {
+impl crate::thrift::TSerializable for FileMetaData {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<FileMetaData> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<i32> = None;
     let mut f_2: Option<Vec<SchemaElement>> = None;
@@ -4982,7 +4982,7 @@ impl TSerializable for FileMetaData {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("FileMetaData");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("version", TType::I32, 1))?;
@@ -5068,8 +5068,8 @@ impl FileCryptoMetaData {
   }
 }
 
-impl TSerializable for FileCryptoMetaData {
-  fn read_from_in_protocol(i_prot: &mut dyn TInputProtocol) -> thrift::Result<FileCryptoMetaData> {
+impl crate::thrift::TSerializable for FileCryptoMetaData {
+  fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<FileCryptoMetaData> {
     i_prot.read_struct_begin()?;
     let mut f_1: Option<EncryptionAlgorithm> = None;
     let mut f_2: Option<Vec<u8>> = None;
@@ -5102,7 +5102,7 @@ impl TSerializable for FileCryptoMetaData {
     };
     Ok(ret)
   }
-  fn write_to_out_protocol(&self, o_prot: &mut dyn TOutputProtocol) -> thrift::Result<()> {
+  fn write_to_out_protocol<T: TOutputProtocol>(&self, o_prot: &mut T) -> thrift::Result<()> {
     let struct_ident = TStructIdentifier::new("FileCryptoMetaData");
     o_prot.write_struct_begin(&struct_ident)?;
     o_prot.write_field_begin(&TFieldIdentifier::new("encryption_algorithm", TType::Struct, 1))?;
diff --git a/parquet/src/lib.rs b/parquet/src/lib.rs
index 2371f8837bb0..f1612c90cc2a 100644
--- a/parquet/src/lib.rs
+++ b/parquet/src/lib.rs
@@ -88,3 +88,5 @@ pub mod bloom_filter;
 pub mod file;
 pub mod record;
 pub mod schema;
+
+pub mod thrift;
diff --git a/parquet/src/thrift.rs b/parquet/src/thrift.rs
new file mode 100644
index 000000000000..57f52edc6ef0
--- /dev/null
+++ b/parquet/src/thrift.rs
@@ -0,0 +1,284 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Custom thrift definitions
+
+use thrift::protocol::{
+    TFieldIdentifier, TInputProtocol, TListIdentifier, TMapIdentifier,
+    TMessageIdentifier, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType,
+};
+
+/// Reads and writes the struct to Thrift protocols.
+///
+/// Unlike [`thrift::protocol::TSerializable`] this uses generics instead of trait objects
+pub trait TSerializable: Sized {
+    fn read_from_in_protocol<T: TInputProtocol>(i_prot: &mut T) -> thrift::Result<Self>;
+    fn write_to_out_protocol<T: TOutputProtocol>(
+        &self,
+        o_prot: &mut T,
+    ) -> thrift::Result<()>;
+}
+
+/// A more performant implementation of [`TCompactInputProtocol`] that reads a slice
+///
+/// [`TCompactInputProtocol`]: thrift::protocol::TCompactInputProtocol
+pub(crate) struct TCompactSliceInputProtocol<'a> {
+    buf: &'a [u8],
+    // Identifier of the last field deserialized for a struct.
+    last_read_field_id: i16,
+    // Stack of the last read field ids (a new entry is added each time a nested struct is read).
+    read_field_id_stack: Vec<i16>,
+    // Boolean value for a field.
+    // Saved because boolean fields and their value are encoded in a single byte,
+    // and reading the field only occurs after the field id is read.
+    pending_read_bool_value: Option<bool>,
+}
+
+impl<'a> TCompactSliceInputProtocol<'a> {
+    pub fn new(buf: &'a [u8]) -> Self {
+        Self {
+            buf,
+            last_read_field_id: 0,
+            read_field_id_stack: Vec::with_capacity(16),
+            pending_read_bool_value: None,
+        }
+    }
+
+    pub fn as_slice(&self) -> &'a [u8] {
+        self.buf
+    }
+
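+    // Reads an unsigned LEB128 variable-length integer: each byte carries
+    // seven payload bits, least-significant group first, and a set high bit
+    // means another byte follows. For example, the bytes [0x96, 0x01]
+    // decode to 0x16 | (0x01 << 7) = 150.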
+    fn read_vlq(&mut self) -> thrift::Result<u64> {
+        let mut in_progress = 0;
+        let mut shift = 0;
+        loop {
+            let byte = self.read_byte()?;
+            in_progress |= ((byte & 0x7F) as u64) << shift;
+            shift += 7;
+            if byte & 0x80 == 0 {
+                return Ok(in_progress);
+            }
+        }
+    }
+
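+    // Undoes zigzag encoding, which maps signed values to unsigned ones so
+    // that small magnitudes stay small: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
+    // The decode is (n >> 1) ^ -(n & 1), so e.g. 3 -> -2 and 4 -> 2.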
+    fn read_zig_zag(&mut self) -> thrift::Result<i64> {
+        let val = self.read_vlq()?;
+        Ok((val >> 1) as i64 ^ -((val & 1) as i64))
+    }
+
+    fn read_list_set_begin(&mut self) -> thrift::Result<(TType, i32)> {
+        let header = self.read_byte()?;
+        let element_type = collection_u8_to_type(header & 0x0F)?;
+
+        let possible_element_count = (header & 0xF0) >> 4;
+        let element_count = if possible_element_count != 15 {
+            // high bits set high if count and type encoded separately
+            possible_element_count as i32
+        } else {
+            self.read_vlq()? as _
+        };
+
+        Ok((element_type, element_count))
+    }
+}
+
+impl<'a> TInputProtocol for TCompactSliceInputProtocol<'a> {
+    fn read_message_begin(&mut self) -> thrift::Result<TMessageIdentifier> {
+        unimplemented!()
+    }
+
+    fn read_message_end(&mut self) -> thrift::Result<()> {
+        unimplemented!()
+    }
+
+    fn read_struct_begin(&mut self) -> thrift::Result<Option<TStructIdentifier>> {
+        self.read_field_id_stack.push(self.last_read_field_id);
+        self.last_read_field_id = 0;
+        Ok(None)
+    }
+
+    fn read_struct_end(&mut self) -> thrift::Result<()> {
+        self.last_read_field_id = self
+            .read_field_id_stack
+            .pop()
+            .expect("should have previous field ids");
+        Ok(())
+    }
+
+    fn read_field_begin(&mut self) -> thrift::Result<TFieldIdentifier> {
+        // we can read at least one byte, which is:
+        // - the type
+        // - the field delta and the type
+        let field_type = self.read_byte()?;
+        let field_delta = (field_type & 0xF0) >> 4;
+        let field_type = match field_type & 0x0F {
+            0x01 => {
+                self.pending_read_bool_value = Some(true);
+                Ok(TType::Bool)
+            }
+            0x02 => {
+                self.pending_read_bool_value = Some(false);
+                Ok(TType::Bool)
+            }
+            ttu8 => u8_to_type(ttu8),
+        }?;
+
+        match field_type {
+            TType::Stop => Ok(
+                TFieldIdentifier::new::<Option<String>, String, Option<i16>>(
+                    None,
+                    TType::Stop,
+                    None,
+                ),
+            ),
+            _ => {
+                if field_delta != 0 {
+                    self.last_read_field_id += field_delta as i16;
+                } else {
+                    self.last_read_field_id = self.read_i16()?;
+                };
+
+                Ok(TFieldIdentifier {
+                    name: None,
+                    field_type,
+                    id: Some(self.last_read_field_id),
+                })
+            }
+        }
+    }
+
+    fn read_field_end(&mut self) -> thrift::Result<()> {
+        Ok(())
+    }
+
+    fn read_bool(&mut self) -> thrift::Result<bool> {
+        match self.pending_read_bool_value.take() {
+            Some(b) => Ok(b),
+            None => {
+                let b = self.read_byte()?;
+                match b {
+                    0x01 => Ok(true),
+                    0x02 => Ok(false),
+                    unkn => Err(thrift::Error::Protocol(thrift::ProtocolError {
+                        kind: thrift::ProtocolErrorKind::InvalidData,
+                        message: format!("cannot convert {} into bool", unkn),
+                    })),
+                }
+            }
+        }
+    }
+
+    fn read_bytes(&mut self) -> thrift::Result<Vec<u8>> {
+        let len = self.read_vlq()? as usize;
+        let ret = self.buf.get(..len).ok_or_else(eof_error)?.to_vec();
+        self.buf = &self.buf[len..];
+        Ok(ret)
+    }
+
+    fn read_i8(&mut self) -> thrift::Result<i8> {
+        Ok(self.read_byte()? as _)
+    }
+
+    fn read_i16(&mut self) -> thrift::Result<i16> {
+        Ok(self.read_zig_zag()? as _)
+    }
+
+    fn read_i32(&mut self) -> thrift::Result<i32> {
+        Ok(self.read_zig_zag()? as _)
+    }
+
+    fn read_i64(&mut self) -> thrift::Result<i64> {
+        self.read_zig_zag()
+    }
+
+    fn read_double(&mut self) -> thrift::Result<f64> {
+        let slice = (self.buf[..8]).try_into().unwrap();
+        self.buf = &self.buf[8..];
+        Ok(f64::from_le_bytes(slice))
+    }
+
+    fn read_string(&mut self) -> thrift::Result<String> {
+        let bytes = self.read_bytes()?;
+        String::from_utf8(bytes).map_err(From::from)
+    }
+
+    fn read_list_begin(&mut self) -> thrift::Result<TListIdentifier> {
+        let (element_type, element_count) = self.read_list_set_begin()?;
+        Ok(TListIdentifier::new(element_type, element_count))
+    }
+
+    fn read_list_end(&mut self) -> thrift::Result<()> {
+        Ok(())
+    }
+
+    fn read_set_begin(&mut self) -> thrift::Result<TSetIdentifier> {
+        unimplemented!()
+    }
+
+    fn read_set_end(&mut self) -> thrift::Result<()> {
+        unimplemented!()
+    }
+
+    fn read_map_begin(&mut self) -> thrift::Result<TMapIdentifier> {
+        unimplemented!()
+    }
+
+    fn read_map_end(&mut self) -> thrift::Result<()> {
+        Ok(())
+    }
+
+    #[inline]
+    fn read_byte(&mut self) -> thrift::Result<u8> {
+        let ret = *self.buf.first().ok_or_else(eof_error)?;
+        self.buf = &self.buf[1..];
+        Ok(ret)
+    }
+}
+
+fn collection_u8_to_type(b: u8) -> thrift::Result<TType> {
+    match b {
+        0x01 => Ok(TType::Bool),
+        o => u8_to_type(o),
+    }
+}
+
+fn u8_to_type(b: u8) -> thrift::Result<TType> {
+    match b {
+        0x00 => Ok(TType::Stop),
+        0x03 => Ok(TType::I08), // equivalent to TType::Byte
+        0x04 => Ok(TType::I16),
+        0x05 => Ok(TType::I32),
+        0x06 => Ok(TType::I64),
+        0x07 => Ok(TType::Double),
+        0x08 => Ok(TType::String),
+        0x09 => Ok(TType::List),
+        0x0A => Ok(TType::Set),
+        0x0B => Ok(TType::Map),
+        0x0C => Ok(TType::Struct),
+        unkn => Err(thrift::Error::Protocol(thrift::ProtocolError {
+            kind: thrift::ProtocolErrorKind::InvalidData,
+            message: format!("cannot convert {} into TType", unkn),
+        })),
+    }
+}
+
+fn eof_error() -> thrift::Error {
+    thrift::Error::Transport(thrift::TransportError {
+        kind: thrift::TransportErrorKind::EndOfFile,
+        message: "Unexpected EOF".to_string(),
+    })
+}
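
For orientation, a minimal sketch (not part of the patch) of how the two additions compose inside the parquet crate; the `crate::format::FileMetaData` path for the generated types is an assumption, as this diff does not show the module name:

    use crate::format::FileMetaData; // assumed path for the generated types
    use crate::thrift::{TCompactSliceInputProtocol, TSerializable};

    /// Decode a thrift-compact-encoded footer directly from an in-memory slice.
    fn decode_footer(metadata: &[u8]) -> thrift::Result<FileMetaData> {
        // The protocol borrows the slice, so there is no Read-based buffering
        // or copying as in thrift::protocol::TCompactInputProtocol.
        let mut prot = TCompactSliceInputProtocol::new(metadata);
        // Statically dispatched through the new generic trait: the read is
        // monomorphized for TCompactSliceInputProtocol rather than going
        // through &mut dyn TInputProtocol.
        FileMetaData::read_from_in_protocol(&mut prot)
    }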