From d9aaa437ca4ebf5a3500c865272243612862c7d4 Mon Sep 17 00:00:00 2001 From: Joseph Rance <56409230+Joseph-Rance@users.noreply.github.com> Date: Mon, 30 Oct 2023 11:40:34 +0000 Subject: [PATCH] Add `RecordReader` trait and proc macro to implement it for a struct (#4773) * add and implement RecordReader trait for rust structs * Fix typo in comment * run cargo fmt * partially solve issues raised in review * remove references * change interface to use vectors * change interface to use vectors in as well * update comments * remove intitialisation requirement * prevent conflicts with existing default implementation * update documentation * run cargo fmt * change writer back to slice * change 'Handle' back to 'Derive' for RecordWriter macro in readme --------- Co-authored-by: joseph rance --- parquet/src/record/mod.rs | 2 + parquet/src/record/record_reader.rs | 30 +++ parquet/src/record/record_writer.rs | 4 + parquet_derive/README.md | 51 ++++- parquet_derive/src/lib.rs | 88 +++++++- parquet_derive/src/parquet_field.rs | 338 ++++++++++++++++++++++++++-- parquet_derive_test/src/lib.rs | 70 +++++- 7 files changed, 553 insertions(+), 30 deletions(-) create mode 100644 parquet/src/record/record_reader.rs diff --git a/parquet/src/record/mod.rs b/parquet/src/record/mod.rs index 771d8058c9c1..f40e91418da1 100644 --- a/parquet/src/record/mod.rs +++ b/parquet/src/record/mod.rs @@ -19,6 +19,7 @@ mod api; pub mod reader; +mod record_reader; mod record_writer; mod triplet; @@ -26,5 +27,6 @@ pub use self::{ api::{ Field, List, ListAccessor, Map, MapAccessor, Row, RowAccessor, RowColumnIter, RowFormatter, }, + record_reader::RecordReader, record_writer::RecordWriter, }; diff --git a/parquet/src/record/record_reader.rs b/parquet/src/record/record_reader.rs new file mode 100644 index 000000000000..bcfeb95dcdf4 --- /dev/null +++ b/parquet/src/record/record_reader.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::super::errors::ParquetError; +use super::super::file::reader::RowGroupReader; + +/// Read up to `num_records` records from `row_group_reader` into `self`. +/// The type parameter `T` is used to work around the rust orphan rule +/// when implementing on types such as `Vec`. +pub trait RecordReader { + fn read_from_row_group( + &mut self, + row_group_reader: &mut dyn RowGroupReader, + num_records: usize, + ) -> Result<(), ParquetError>; +} diff --git a/parquet/src/record/record_writer.rs b/parquet/src/record/record_writer.rs index 62099051f513..0b2b95ef7dea 100644 --- a/parquet/src/record/record_writer.rs +++ b/parquet/src/record/record_writer.rs @@ -20,6 +20,10 @@ use crate::schema::types::TypePtr; use super::super::errors::ParquetError; use super::super::file::writer::SerializedRowGroupWriter; +/// `write_to_row_group` writes from `self` into `row_group_writer` +/// `schema` builds the schema used by `row_group_writer` +/// The type parameter `T` is used to work around the rust orphan rule +/// when implementing on types such as `&[T]`. 
pub trait RecordWriter { fn write_to_row_group( &self, diff --git a/parquet_derive/README.md b/parquet_derive/README.md index b20721079c2d..c267a92430e0 100644 --- a/parquet_derive/README.md +++ b/parquet_derive/README.md @@ -19,9 +19,9 @@ # Parquet Derive -A crate for deriving `RecordWriter` for arbitrary, _simple_ structs. This does not generate writers for arbitrarily nested -structures. It only works for primitives and a few generic structures and -various levels of reference. Please see features checklist for what is currently +A crate for deriving `RecordWriter` and `RecordReader` for arbitrary, _simple_ structs. This does not +generate readers or writers for arbitrarily nested structures. It only works for primitives and a few +generic structures and various levels of reference. Please see features checklist for what is currently supported. Derive also has some support for the chrono time library. You must must enable the `chrono` feature to get this support. @@ -77,16 +77,55 @@ writer.close_row_group(row_group).unwrap(); writer.close().unwrap(); ``` +Example usage of deriving a `RecordReader` for your struct: + +```rust +use parquet::file::{serialized_reader::SerializedFileReader, reader::FileReader}; +use parquet_derive::ParquetRecordReader; + +#[derive(ParquetRecordReader)] +struct ACompleteRecord { + pub a_bool: bool, + pub a_string: String, + pub i16: i16, + pub i32: i32, + pub u64: u64, + pub isize: isize, + pub float: f32, + pub double: f64, + pub now: chrono::NaiveDateTime, + pub byte_vec: Vec, +} + +// Initialize your parquet file +let reader = SerializedFileReader::new(file).unwrap(); +let mut row_group = reader.get_row_group(0).unwrap(); + +// create your records vector to read into +let mut chunks: Vec = Vec::new(); + +// The derived `RecordReader` takes over here +chunks.read_from_row_group(&mut *row_group, 1).unwrap(); +``` + ## Features - [x] Support writing `String`, `&str`, `bool`, `i32`, `f32`, `f64`, `Vec` - [ ] Support writing 
dictionaries - [x] Support writing logical types like timestamp -- [x] Derive definition_levels for `Option` -- [ ] Derive definition levels for nested structures +- [x] Derive definition_levels for `Option` for writing +- [ ] Derive definition levels for nested structures for writing - [ ] Derive writing tuple struct - [ ] Derive writing `tuple` container types +- [x] Support reading `String`, `&str`, `bool`, `i32`, `f32`, `f64`, `Vec` +- [ ] Support reading/writing dictionaries +- [x] Support reading/writing logical types like timestamp +- [ ] Handle definition_levels for `Option` for reading +- [ ] Handle definition levels for nested structures for reading +- [ ] Derive reading/writing tuple struct +- [ ] Derive reading/writing `tuple` container types + ## Requirements - Same as `parquet-rs` @@ -103,4 +142,4 @@ To compile and view in the browser, run `cargo doc --no-deps --open`. ## License -Licensed under the Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0. +Licensed under the Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0. \ No newline at end of file diff --git a/parquet_derive/src/lib.rs b/parquet_derive/src/lib.rs index c6641cd8091d..671a46db0f31 100644 --- a/parquet_derive/src/lib.rs +++ b/parquet_derive/src/lib.rs @@ -44,7 +44,7 @@ mod parquet_field; /// use parquet::file::writer::SerializedFileWriter; /// /// use std::sync::Arc; -// +/// /// #[derive(ParquetRecordWriter)] /// struct ACompleteRecord<'a> { /// pub a_bool: bool, @@ -137,3 +137,89 @@ pub fn parquet_record_writer(input: proc_macro::TokenStream) -> proc_macro::Toke } }).into() } + +/// Derive flat, simple RecordReader implementations. Works by parsing +/// a struct tagged with `#[derive(ParquetRecordReader)]` and emitting +/// the correct reading code for each field of the struct. Column readers +/// are generated in the order they are defined. 
+/// +/// It is up to the programmer to keep the order of the struct +/// fields lined up with the schema. +/// +/// Example: +/// +/// ```ignore +/// use parquet::file::{serialized_reader::SerializedFileReader, reader::FileReader}; +/// use parquet_derive::{ParquetRecordReader}; +/// +/// #[derive(ParquetRecordReader)] +/// struct ACompleteRecord { +/// pub a_bool: bool, +/// pub a_string: String, +/// } +/// +/// pub fn read_some_records() -> Vec { +/// let mut samples: Vec = Vec::new(); +/// +/// let reader = SerializedFileReader::new(file).unwrap(); +/// let mut row_group = reader.get_row_group(0).unwrap(); +/// samples.read_from_row_group(&mut *row_group, 1).unwrap(); +/// samples +/// } +/// ``` +/// +#[proc_macro_derive(ParquetRecordReader)] +pub fn parquet_record_reader(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input: DeriveInput = parse_macro_input!(input as DeriveInput); + let fields = match input.data { + Data::Struct(DataStruct { fields, .. }) => fields, + Data::Enum(_) => unimplemented!("Enum currently is not supported"), + Data::Union(_) => unimplemented!("Union currently is not supported"), + }; + + let field_infos: Vec<_> = fields.iter().map(parquet_field::Field::from).collect(); + let field_names: Vec<_> = fields.iter().map(|f| f.ident.clone()).collect(); + let reader_snippets: Vec = + field_infos.iter().map(|x| x.reader_snippet()).collect(); + let i: Vec<_> = (0..reader_snippets.len()).collect(); + + let derived_for = input.ident; + let generics = input.generics; + + (quote! 
{ + + impl #generics ::parquet::record::RecordReader<#derived_for #generics> for Vec<#derived_for #generics> { + fn read_from_row_group( + &mut self, + row_group_reader: &mut dyn ::parquet::file::reader::RowGroupReader, + num_records: usize, + ) -> Result<(), ::parquet::errors::ParquetError> { + use ::parquet::column::reader::ColumnReader; + + let mut row_group_reader = row_group_reader; + + for _ in 0..num_records { + self.push(#derived_for { + #( + #field_names: Default::default() + ),* + }) + } + + let records = self; // Used by all the reader snippets to be more clear + + #( + { + if let Ok(mut column_reader) = row_group_reader.get_column_reader(#i) { + #reader_snippets + } else { + return Err(::parquet::errors::ParquetError::General("Failed to get next column".into())) + } + } + );* + + Ok(()) + } + } + }).into() +} diff --git a/parquet_derive/src/parquet_field.rs b/parquet_derive/src/parquet_field.rs index e629bfe757ab..0ac95c2864e5 100644 --- a/parquet_derive/src/parquet_field.rs +++ b/parquet_derive/src/parquet_field.rs @@ -219,6 +219,72 @@ impl Field { } } + /// Takes the parsed field of the struct and emits a valid + /// column reader snippet. Should match exactly what you + /// would write by hand. + /// + /// Can only generate readers for basic structs, for example: + /// + /// struct Record { + /// a_bool: bool + /// } + /// + /// but not + /// + /// struct UnsupportedNestedRecord { + /// a_property: bool, + /// nested_record: Record + /// } + /// + /// because this parsing logic is not sophisticated enough for definition + /// levels beyond 2. + /// + /// `Option` types and references are not supported. + pub fn reader_snippet(&self) -> proc_macro2::TokenStream { + let ident = &self.ident; + let column_reader = self.ty.column_reader(); + let parquet_type = self.ty.physical_type_as_rust(); + + // generate the code to read the column into a vector `vals` + let write_batch_expr = quote! 
{ + let mut vals_vec = Vec::new(); + vals_vec.resize(num_records, Default::default()); + let mut vals: &mut [#parquet_type] = vals_vec.as_mut_slice(); + if let #column_reader(mut typed) = column_reader { + typed.read_records(num_records, None, None, vals)?; + } else { + panic!("Schema and struct disagree on type for {}", stringify!{#ident}); + } + }; + + // generate the code to convert each element of `vals` to the correct type and then write + // it to its field in the corresponding struct + let vals_writer = match &self.ty { + Type::TypePath(_) => self.copied_direct_fields(), + Type::Reference(_, ref first_type) => match **first_type { + Type::TypePath(_) => self.copied_direct_fields(), + Type::Slice(ref second_type) => match **second_type { + Type::TypePath(_) => self.copied_direct_fields(), + ref f => unimplemented!("Unsupported: {:#?}", f), + }, + ref f => unimplemented!("Unsupported: {:#?}", f), + }, + Type::Vec(ref first_type) => match **first_type { + Type::TypePath(_) => self.copied_direct_fields(), + ref f => unimplemented!("Unsupported: {:#?}", f), + }, + f => unimplemented!("Unsupported: {:#?}", f), + }; + + quote! { + { + #write_batch_expr + + #vals_writer + } + } + } + pub fn parquet_type(&self) -> proc_macro2::TokenStream { // TODO: Support group types // TODO: Add length if dealing with fixedlenbinary @@ -319,27 +385,31 @@ impl Field { } } + // generates code to read `field_name` from each record into a vector `vals` fn copied_direct_vals(&self) -> proc_macro2::TokenStream { let field_name = &self.ident; - let is_a_byte_buf = self.is_a_byte_buf; - let is_a_timestamp = self.third_party_type == Some(ThirdPartyType::ChronoNaiveDateTime); - let is_a_date = self.third_party_type == Some(ThirdPartyType::ChronoNaiveDate); - let is_a_uuid = self.third_party_type == Some(ThirdPartyType::Uuid); - let access = if is_a_timestamp { - quote! { rec.#field_name.timestamp_millis() } - } else if is_a_date { - quote! 
{ rec.#field_name.signed_duration_since(::chrono::NaiveDate::from_ymd(1970, 1, 1)).num_days() as i32 } - } else if is_a_uuid { - quote! { (&rec.#field_name.to_string()[..]).into() } - } else if is_a_byte_buf { - quote! { (&rec.#field_name[..]).into() } - } else { - // Type might need converting to a physical type - match self.ty.physical_type() { - parquet::basic::Type::INT32 => quote! { rec.#field_name as i32 }, - parquet::basic::Type::INT64 => quote! { rec.#field_name as i64 }, - _ => quote! { rec.#field_name }, + let access = match self.third_party_type { + Some(ThirdPartyType::ChronoNaiveDateTime) => { + quote! { rec.#field_name.timestamp_millis() } + } + Some(ThirdPartyType::ChronoNaiveDate) => { + quote! { rec.#field_name.signed_duration_since(::chrono::NaiveDate::from_ymd(1970, 1, 1)).num_days() as i32 } + } + Some(ThirdPartyType::Uuid) => { + quote! { (&rec.#field_name.to_string()[..]).into() } + } + _ => { + if self.is_a_byte_buf { + quote! { (&rec.#field_name[..]).into() } + } else { + // Type might need converting to a physical type + match self.ty.physical_type() { + parquet::basic::Type::INT32 => quote! { rec.#field_name as i32 }, + parquet::basic::Type::INT64 => quote! { rec.#field_name as i64 }, + _ => quote! { rec.#field_name }, + } + } } }; @@ -348,6 +418,48 @@ impl Field { } } + // generates code to convert each element of `vals` and store it in `field_name` of the matching record in `records` + fn copied_direct_fields(&self) -> proc_macro2::TokenStream { + let field_name = &self.ident; + + let value = match self.third_party_type { + Some(ThirdPartyType::ChronoNaiveDateTime) => { + quote! { ::chrono::naive::NaiveDateTime::from_timestamp_millis(vals[i]).unwrap() } + } + Some(ThirdPartyType::ChronoNaiveDate) => { + quote! 
{ + ::chrono::naive::NaiveDate::from_num_days_from_ce_opt(vals[i] + + ((::chrono::naive::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap() + .signed_duration_since( + ::chrono::naive::NaiveDate::from_ymd_opt(0, 12, 31).unwrap() + ) + ).num_days()) as i32).unwrap() + } + } + Some(ThirdPartyType::Uuid) => { + quote! { ::uuid::Uuid::parse_str(vals[i].data().convert()).unwrap() } + } + _ => match &self.ty { + Type::TypePath(_) => match self.ty.last_part().as_str() { + "String" => quote! { String::from(std::str::from_utf8(vals[i].data()) + .expect("invalid UTF-8 sequence")) }, + t => { + let s: proc_macro2::TokenStream = t.parse().unwrap(); + quote! { vals[i] as #s } + } + }, + Type::Vec(_) => quote! { vals[i].data().to_vec() }, + f => unimplemented!("Unsupported: {:#?}", f), + }, + }; + + quote! { + for (i, r) in &mut records[..num_records].iter_mut().enumerate() { + r.#field_name = #value; + } + } + } + fn optional_definition_levels(&self) -> proc_macro2::TokenStream { let field_name = &self.ident; @@ -396,6 +508,29 @@ impl Type { } } + /// Takes a rust type and returns the appropriate + /// parquet-rs column reader + fn column_reader(&self) -> syn::TypePath { + use parquet::basic::Type as BasicType; + + match self.physical_type() { + BasicType::BOOLEAN => { + syn::parse_quote!(ColumnReader::BoolColumnReader) + } + BasicType::INT32 => syn::parse_quote!(ColumnReader::Int32ColumnReader), + BasicType::INT64 => syn::parse_quote!(ColumnReader::Int64ColumnReader), + BasicType::INT96 => syn::parse_quote!(ColumnReader::Int96ColumnReader), + BasicType::FLOAT => syn::parse_quote!(ColumnReader::FloatColumnReader), + BasicType::DOUBLE => syn::parse_quote!(ColumnReader::DoubleColumnReader), + BasicType::BYTE_ARRAY => { + syn::parse_quote!(ColumnReader::ByteArrayColumnReader) + } + BasicType::FIXED_LEN_BYTE_ARRAY => { + syn::parse_quote!(ColumnReader::FixedLenByteArrayColumnReader) + } + } + } + /// Helper to simplify a nested field definition to its leaf type /// /// Ex: @@ -515,6 
+650,23 @@ impl Type { } } + fn physical_type_as_rust(&self) -> proc_macro2::TokenStream { + use parquet::basic::Type as BasicType; + + match self.physical_type() { + BasicType::BOOLEAN => quote! { bool }, + BasicType::INT32 => quote! { i32 }, + BasicType::INT64 => quote! { i64 }, + BasicType::INT96 => unimplemented!("96-bit int currently is not supported"), + BasicType::FLOAT => quote! { f32 }, + BasicType::DOUBLE => quote! { f64 }, + BasicType::BYTE_ARRAY => quote! { ::parquet::data_type::ByteArray }, + BasicType::FIXED_LEN_BYTE_ARRAY => { + quote! { ::parquet::data_type::FixedLenByteArray } + } + } + } + fn logical_type(&self) -> proc_macro2::TokenStream { let last_part = self.last_part(); let leaf_type = self.leaf_type_recursive(); @@ -713,6 +865,39 @@ mod test { ) } + #[test] + fn test_generating_a_simple_reader_snippet() { + let snippet: proc_macro2::TokenStream = quote! { + struct ABoringStruct { + counter: usize, + } + }; + + let fields = extract_fields(snippet); + let counter = Field::from(&fields[0]); + + let snippet = counter.reader_snippet().to_string(); + assert_eq!( + snippet, + (quote! { + { + let mut vals_vec = Vec::new(); + vals_vec.resize(num_records, Default::default()); + let mut vals: &mut[i64] = vals_vec.as_mut_slice(); + if let ColumnReader::Int64ColumnReader(mut typed) = column_reader { + typed.read_records(num_records, None, None, vals)?; + } else { + panic!("Schema and struct disagree on type for {}", stringify!{ counter }); + } + for (i, r) in &mut records[..num_records].iter_mut().enumerate() { + r.counter = vals[i] as usize; + } + } + }) + .to_string() + ) + } + #[test] fn test_optional_to_writer_snippet() { let struct_def: proc_macro2::TokenStream = quote! { @@ -822,6 +1007,32 @@ mod test { ); } + #[test] + fn test_converting_to_column_reader_type() { + let snippet: proc_macro2::TokenStream = quote! 
{ + struct ABasicStruct { + yes_no: bool, + name: String, + } + }; + + let fields = extract_fields(snippet); + let processed: Vec<_> = fields.iter().map(Field::from).collect(); + + let column_readers: Vec<_> = processed + .iter() + .map(|field| field.ty.column_reader()) + .collect(); + + assert_eq!( + column_readers, + vec![ + syn::parse_quote!(ColumnReader::BoolColumnReader), + syn::parse_quote!(ColumnReader::ByteArrayColumnReader) + ] + ); + } + #[test] fn convert_basic_struct() { let snippet: proc_macro2::TokenStream = quote! { @@ -995,7 +1206,7 @@ mod test { } #[test] - fn test_chrono_timestamp_millis() { + fn test_chrono_timestamp_millis_write() { let snippet: proc_macro2::TokenStream = quote! { struct ATimestampStruct { henceforth: chrono::NaiveDateTime, @@ -1038,7 +1249,34 @@ mod test { } #[test] - fn test_chrono_date() { + fn test_chrono_timestamp_millis_read() { + let snippet: proc_macro2::TokenStream = quote! { + struct ATimestampStruct { + henceforth: chrono::NaiveDateTime, + } + }; + + let fields = extract_fields(snippet); + let when = Field::from(&fields[0]); + assert_eq!(when.reader_snippet().to_string(),(quote!{ + { + let mut vals_vec = Vec::new(); + vals_vec.resize(num_records, Default::default()); + let mut vals: &mut[i64] = vals_vec.as_mut_slice(); + if let ColumnReader::Int64ColumnReader(mut typed) = column_reader { + typed.read_records(num_records, None, None, vals)?; + } else { + panic!("Schema and struct disagree on type for {}", stringify!{ henceforth }); + } + for (i, r) in &mut records[..num_records].iter_mut().enumerate() { + r.henceforth = ::chrono::naive::NaiveDateTime::from_timestamp_millis(vals[i]).unwrap(); + } + } + }).to_string()); + } + + #[test] + fn test_chrono_date_write() { let snippet: proc_macro2::TokenStream = quote! { struct ATimestampStruct { henceforth: chrono::NaiveDate, @@ -1081,7 +1319,38 @@ mod test { } #[test] - fn test_uuid() { + fn test_chrono_date_read() { + let snippet: proc_macro2::TokenStream = quote! 
{ + struct ATimestampStruct { + henceforth: chrono::NaiveDate, + } + }; + + let fields = extract_fields(snippet); + let when = Field::from(&fields[0]); + assert_eq!(when.reader_snippet().to_string(),(quote!{ + { + let mut vals_vec = Vec::new(); + vals_vec.resize(num_records, Default::default()); + let mut vals: &mut [i32] = vals_vec.as_mut_slice(); + if let ColumnReader::Int32ColumnReader(mut typed) = column_reader { + typed.read_records(num_records, None, None, vals)?; + } else { + panic!("Schema and struct disagree on type for {}", stringify!{ henceforth }); + } + for (i, r) in &mut records[..num_records].iter_mut().enumerate() { + r.henceforth = ::chrono::naive::NaiveDate::from_num_days_from_ce_opt(vals[i] + + ((::chrono::naive::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap() + .signed_duration_since( + ::chrono::naive::NaiveDate::from_ymd_opt(0, 12, 31).unwrap() + )).num_days()) as i32).unwrap(); + } + } + }).to_string()); + } + + #[test] + fn test_uuid_write() { let snippet: proc_macro2::TokenStream = quote! { struct AUuidStruct { unique_id: uuid::Uuid, @@ -1123,6 +1392,33 @@ mod test { }).to_string()); } + #[test] + fn test_uuid_read() { + let snippet: proc_macro2::TokenStream = quote! 
{ + struct AUuidStruct { + unique_id: uuid::Uuid, + } + }; + + let fields = extract_fields(snippet); + let when = Field::from(&fields[0]); + assert_eq!(when.reader_snippet().to_string(),(quote!{ + { + let mut vals_vec = Vec::new(); + vals_vec.resize(num_records, Default::default()); + let mut vals: &mut [::parquet::data_type::ByteArray] = vals_vec.as_mut_slice(); + if let ColumnReader::ByteArrayColumnReader(mut typed) = column_reader { + typed.read_records(num_records, None, None, vals)?; + } else { + panic!("Schema and struct disagree on type for {}", stringify!{ unique_id }); + } + for (i, r) in &mut records[..num_records].iter_mut().enumerate() { + r.unique_id = ::uuid::Uuid::parse_str(vals[i].data().convert()).unwrap(); + } + } + }).to_string()); + } + #[test] fn test_converted_type() { let snippet: proc_macro2::TokenStream = quote! { diff --git a/parquet_derive_test/src/lib.rs b/parquet_derive_test/src/lib.rs index d377fb0a62af..a8b631ecc024 100644 --- a/parquet_derive_test/src/lib.rs +++ b/parquet_derive_test/src/lib.rs @@ -17,7 +17,7 @@ #![allow(clippy::approx_constant)] -use parquet_derive::ParquetRecordWriter; +use parquet_derive::{ParquetRecordReader, ParquetRecordWriter}; #[derive(ParquetRecordWriter)] struct ACompleteRecord<'a> { @@ -49,6 +49,21 @@ struct ACompleteRecord<'a> { pub borrowed_maybe_borrowed_byte_vec: &'a Option<&'a [u8]>, } +#[derive(PartialEq, ParquetRecordWriter, ParquetRecordReader, Debug)] +struct APartiallyCompleteRecord { + pub bool: bool, + pub string: String, + pub i16: i16, + pub i32: i32, + pub u64: u64, + pub isize: isize, + pub float: f32, + pub double: f64, + pub now: chrono::NaiveDateTime, + pub date: chrono::NaiveDate, + pub byte_vec: Vec, +} + #[cfg(test)] mod tests { use super::*; @@ -56,7 +71,8 @@ mod tests { use std::{env, fs, io::Write, sync::Arc}; use parquet::{ - file::writer::SerializedFileWriter, record::RecordWriter, + file::writer::SerializedFileWriter, + record::{RecordReader, RecordWriter}, 
schema::parser::parse_message_type, }; @@ -147,6 +163,56 @@ mod tests { writer.close().unwrap(); } + #[test] + fn test_parquet_derive_read_write_combined() { + let file = get_temp_file("test_parquet_derive_combined", &[]); + + let mut drs: Vec = vec![APartiallyCompleteRecord { + bool: true, + string: "a string".into(), + i16: -45, + i32: 456, + u64: 4563424, + isize: -365, + float: 3.5, + double: std::f64::NAN, + now: chrono::Utc::now().naive_local(), + date: chrono::naive::NaiveDate::from_ymd_opt(2015, 3, 14).unwrap(), + byte_vec: vec![0x65, 0x66, 0x67], + }]; + + let mut out: Vec = Vec::new(); + + use parquet::file::{reader::FileReader, serialized_reader::SerializedFileReader}; + + let generated_schema = drs.as_slice().schema().unwrap(); + + let props = Default::default(); + let mut writer = + SerializedFileWriter::new(file.try_clone().unwrap(), generated_schema, props).unwrap(); + + let mut row_group = writer.next_row_group().unwrap(); + drs.as_slice().write_to_row_group(&mut row_group).unwrap(); + row_group.close().unwrap(); + writer.close().unwrap(); + + let reader = SerializedFileReader::new(file).unwrap(); + + let mut row_group = reader.get_row_group(0).unwrap(); + out.read_from_row_group(&mut *row_group, 1).unwrap(); + + // correct for rounding error when writing milliseconds + drs[0].now = + chrono::naive::NaiveDateTime::from_timestamp_millis(drs[0].now.timestamp_millis()) + .unwrap(); + + assert!(out[0].double.is_nan()); // these three lines are necessary because NAN != NAN + out[0].double = 0.; + drs[0].double = 0.; + + assert_eq!(drs[0], out[0]); + } + /// Returns file handle for a temp file in 'target' directory with a provided content pub fn get_temp_file(file_name: &str, content: &[u8]) -> fs::File { // build tmp path to a file in "target/debug/testdata"