From eba5f150675ad089cbd121da11ff43adc0f09a11 Mon Sep 17 00:00:00 2001
From: Yanxin Xiang
Date: Sat, 20 Apr 2024 17:41:34 -0500
Subject: [PATCH] Provide Arrow Schema Hint to Parquet Reader

Add an `ArrowReaderBuilder::with_reinterpret_schema` API that lets callers
supply an Arrow schema to use in place of the one inferred from the Parquet
file, provided every field can be reinterpreted without converting any data.
Also adds `arrow_schema::can_reinterpret` to express that check.

---
 arrow-schema/src/datatype.rs              |  61 ++++++++++
 parquet/src/arrow/array_reader/builder.rs |   9 +-
 parquet/src/arrow/arrow_reader/mod.rs     | 129 +++++++++++++++++++++-
 3 files changed, 192 insertions(+), 7 deletions(-)

diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs
index 449d363db671..1eb27f6e8596 100644
--- a/arrow-schema/src/datatype.rs
+++ b/arrow-schema/src/datatype.rs
@@ -688,6 +688,32 @@ pub const DECIMAL256_MAX_SCALE: i8 = 76;
 /// values
 pub const DECIMAL_DEFAULT_SCALE: i8 = 10;
 
+/// Returns `true` if values of `from_type` can be reinterpreted as values of
+/// `to_type` without converting the underlying representation: the two types
+/// are equal, are signed and unsigned integers of the same width, or are
+/// timestamps with the same time unit that differ only in time zone.
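+///
+/// A brief illustrative example, mirroring the unit tests below:
+///
+/// ```
+/// use std::sync::Arc;
+///
+/// use arrow_schema::{can_reinterpret, DataType, TimeUnit};
+///
+/// // Same-width signed and unsigned integers share a representation
+/// assert!(can_reinterpret(&DataType::Int32, &DataType::UInt32));
+/// // Widening is a conversion, not a reinterpretation
+/// assert!(!can_reinterpret(&DataType::Int32, &DataType::Int64));
+/// // Timestamps must agree on the time unit; the time zone may differ
+/// assert!(can_reinterpret(
+///     &DataType::Timestamp(TimeUnit::Second, None),
+///     &DataType::Timestamp(TimeUnit::Second, Some(Arc::from("UTC")))
+/// ));
+/// ```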
+pub fn can_reinterpret(from_type: &DataType, to_type: &DataType) -> bool {
+    use DataType::*;
+    if from_type == to_type {
+        return true;
+    }
+    // Signed and unsigned integers of the same width share a representation
+    let is_compatible_integer = matches!(
+        (from_type, to_type),
+        (Int8, UInt8)
+            | (UInt8, Int8)
+            | (Int16, UInt16)
+            | (UInt16, Int16)
+            | (Int32, UInt32)
+            | (UInt32, Int32)
+            | (Int64, UInt64)
+            | (UInt64, Int64)
+    );
+
+    // Timestamps must agree on the time unit; the time zone is only metadata
+    let is_compatible_timestamp = matches!(
+        (from_type, to_type),
+        (Timestamp(unit1, _), Timestamp(unit2, _)) if unit1 == unit2
+    );
+
+    is_compatible_integer || is_compatible_timestamp
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -985,4 +1011,39 @@ mod tests {
             UnionMode::Dense,
         );
     }
+
+    #[test]
+    fn test_can_reinterpret() {
+        // Integer and unsigned integer reinterpretation
+        assert!(can_reinterpret(&DataType::Int32, &DataType::UInt32));
+        assert!(can_reinterpret(&DataType::UInt32, &DataType::Int32));
+        assert!(!can_reinterpret(&DataType::Int32, &DataType::Int64));
+
+        // Timestamp reinterpretation: same time unit, different time zones
+        let tz_utc = Some(Arc::from("UTC"));
+        let tz_offset = Some(Arc::from("+07:00"));
+        assert!(can_reinterpret(
+            &DataType::Timestamp(TimeUnit::Second, tz_utc.clone()),
+            &DataType::Timestamp(TimeUnit::Second, tz_offset.clone())
+        ));
+        assert!(can_reinterpret(
+            &DataType::Timestamp(TimeUnit::Microsecond, tz_utc.clone()),
+            &DataType::Timestamp(TimeUnit::Microsecond, tz_offset)
+        ));
+        // Differing time units are not reinterpretable
+        assert!(!can_reinterpret(
+            &DataType::Timestamp(TimeUnit::Second, tz_utc.clone()),
+            &DataType::Timestamp(TimeUnit::Millisecond, tz_utc)
+        ));
+
+        // Mixed types are never reinterpretable
+        assert!(!can_reinterpret(&DataType::Int32, &DataType::Float32));
+        assert!(!can_reinterpret(
+            &DataType::Timestamp(TimeUnit::Second, None),
+            &DataType::Int64
+        ));
+    }
 }
diff --git a/parquet/src/arrow/array_reader/builder.rs b/parquet/src/arrow/array_reader/builder.rs
index 958594c93232..f455851cefad 100644
--- a/parquet/src/arrow/array_reader/builder.rs
+++ b/parquet/src/arrow/array_reader/builder.rs
@@ -17,7 +17,7 @@
 
 use std::sync::Arc;
 
-use arrow_schema::{DataType, Fields, SchemaBuilder};
+use arrow_schema::{can_reinterpret, DataType, Fields, SchemaBuilder};
 
 use crate::arrow::array_reader::byte_array::make_byte_view_array_reader;
 use crate::arrow::array_reader::empty_array::make_empty_array_reader;
@@ -315,7 +315,12 @@ fn build_struct_reader(
         if let Some(reader) = build_reader(parquet, mask, row_groups)? {
             // Need to retrieve underlying data type to handle projection
             let child_type = reader.get_data_type().clone();
-            builder.push(arrow.as_ref().clone().with_data_type(child_type));
+            // If the user supplied a schema hint and the reader's type can be
+            // reinterpreted as the hinted type, keep the hinted field as-is
+            if can_reinterpret(&child_type, arrow.data_type()) {
+                builder.push(arrow.as_ref().clone());
+            } else {
+                builder.push(arrow.as_ref().clone().with_data_type(child_type));
+            }
             readers.push(reader);
         }
     }
diff --git a/parquet/src/arrow/arrow_reader/mod.rs b/parquet/src/arrow/arrow_reader/mod.rs
index 78d0fd6da8a9..5615338cd01b 100644
--- a/parquet/src/arrow/arrow_reader/mod.rs
+++ b/parquet/src/arrow/arrow_reader/mod.rs
@@ -23,7 +23,7 @@ use std::sync::Arc;
 use arrow_array::cast::AsArray;
 use arrow_array::Array;
 use arrow_array::{RecordBatch, RecordBatchReader};
-use arrow_schema::{ArrowError, DataType as ArrowType, Schema, SchemaRef};
+use arrow_schema::{can_reinterpret, ArrowError, DataType as ArrowType, Schema, SchemaRef};
 use arrow_select::filter::prep_null_mask_filter;
 
 use crate::arrow::array_reader::{build_array_reader, ArrayReader};
@@ -192,6 +190,81 @@ impl<T> ArrowReaderBuilder<T> {
             ..self
         }
     }
+
+    /// Overrides the Arrow schema reported by this builder
+    pub fn with_new_schema(self, new_schema: SchemaRef) -> Self {
+        Self {
+            schema: new_schema,
+            ..self
+        }
+    }
+
+    /// Overrides the decoded Parquet field metadata used to build the readers
+    pub fn with_new_parquet_field(self, parquet_field: Option<Arc<ParquetField>>) -> Self {
+        Self {
+            fields: parquet_field,
+            ..self
+        }
+    }
+
+    /// Specifies the Arrow schema to use when reading this Parquet file,
+    /// overriding any Arrow type metadata embedded when the file was written.
+    ///
+    /// Each field in `new_schema` must match the corresponding Parquet-derived
+    /// field by name, and its type must be reinterpretable from that field's
+    /// type (see [`can_reinterpret`]); if the provided schema is incompatible,
+    /// the hint is ignored and the file's own schema is used unchanged.
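+    ///
+    /// A minimal sketch of intended usage; `data.parquet` is a hypothetical
+    /// file whose `ts` column was written as `Timestamp(Nanosecond, None)`:
+    ///
+    /// ```no_run
+    /// use std::fs::File;
+    /// use std::sync::Arc;
+    ///
+    /// use arrow_schema::{DataType, Field, Schema, TimeUnit};
+    /// use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
+    ///
+    /// // Reinterpret the naive timestamps as UTC without rewriting any data
+    /// let hint = Arc::new(Schema::new(vec![Field::new(
+    ///     "ts",
+    ///     DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("UTC"))),
+    ///     true,
+    /// )]));
+    ///
+    /// let file = File::open("data.parquet").unwrap();
+    /// let reader = ParquetRecordBatchReaderBuilder::try_new(file)
+    ///     .unwrap()
+    ///     .with_reinterpret_schema(hint)
+    ///     .build()
+    ///     .unwrap();
+    /// // `reader` now yields batches whose `ts` column carries the UTC zone
+    /// ```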
+    pub fn with_reinterpret_schema(self, new_schema: SchemaRef) -> Self {
+        // Only applicable once the file's schema has been decoded into fields
+        if let Some(field_ref) = &self.fields {
+            match &field_ref.arrow_type {
+                arrow_schema::DataType::Struct(existing_fields) => {
+                    let all_fields = new_schema.fields();
+
+                    // The hint must match field-for-field by name, and every
+                    // hinted type must be reinterpretable from the decoded type
+                    if all_fields.len() == existing_fields.len()
+                        && all_fields.iter().zip(existing_fields.iter()).all(
+                            |(new_field, existing_field)| {
+                                new_field.name() == existing_field.name()
+                                    && can_reinterpret(
+                                        existing_field.data_type(),
+                                        new_field.data_type(),
+                                    )
+                            },
+                        )
+                    {
+                        let new_data_type =
+                            arrow_schema::DataType::Struct(new_schema.fields.clone());
+                        // All checks passed: adopt the hinted schema and types
+                        let new_parquet_field = ParquetField {
+                            rep_level: field_ref.rep_level,
+                            def_level: field_ref.def_level,
+                            nullable: field_ref.nullable,
+                            arrow_type: new_data_type,
+                            field_type: field_ref.field_type.clone(),
+                        };
+                        return self
+                            .with_new_schema(new_schema)
+                            .with_new_parquet_field(Some(Arc::new(new_parquet_field)));
+                    }
+                }
+                // A non-struct root decodes to a single leaf column
+                other_type => {
+                    let all_fields = new_schema.all_fields();
+                    if all_fields.len() == 1
+                        && can_reinterpret(other_type, all_fields[0].data_type())
+                    {
+                        let new_parquet_field = ParquetField {
+                            rep_level: field_ref.rep_level,
+                            def_level: field_ref.def_level,
+                            nullable: field_ref.nullable,
+                            arrow_type: all_fields[0].data_type().clone(),
+                            field_type: field_ref.field_type.clone(),
+                        };
+                        return self
+                            .with_new_schema(new_schema)
+                            .with_new_parquet_field(Some(Arc::new(new_parquet_field)));
+                    }
+                }
+            }
+        }
+        // Fall back to the file's own schema: the hint is absent or incompatible
+        self
+    }
 }
 
 /// Options that control how metadata is read for a parquet file
@@ -752,7 +824,7 @@ mod tests {
     use arrow_array::*;
     use arrow_buffer::{i256, ArrowNativeType, Buffer};
     use arrow_data::ArrayDataBuilder;
-    use arrow_schema::{ArrowError, DataType as ArrowDataType, Field, Fields, Schema};
+    use arrow_schema::{ArrowError, DataType as ArrowDataType, Field, Fields, Schema, TimeUnit};
     use arrow_select::concat::concat_batches;
 
     use crate::arrow::arrow_reader::{
@@ -3234,4 +3306,51 @@
             }
         }
     }
+
+    #[test]
+    fn test_reinterpret_timestamp() -> Result<()> {
+        use arrow_schema::DataType;
+
+        let testdata = arrow::util::test_util::parquet_test_data();
+        let path = format!("{testdata}/alltypes_plain.parquet");
+        let file = File::open(path).unwrap();
+        // A schema hint identical to the file's schema, except that the
+        // timestamp column is tagged with an explicit time zone
+        let fields = vec![
+            Field::new("id", DataType::Int32, true),
+            Field::new("bool_col", DataType::Boolean, true),
+            Field::new("tinyint_col", DataType::Int32, true),
+            Field::new("smallint_col", DataType::Int32, true),
+            Field::new("int_col", DataType::Int32, true),
+            Field::new("bigint_col", DataType::Int64, true),
+            Field::new("float_col", DataType::Float32, true),
+            Field::new("double_col", DataType::Float64, true),
+            Field::new("date_string_col", DataType::Binary, true),
+            Field::new("string_col", DataType::Binary, true),
+            Field::new(
"timestamp_col", + DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("+07:00"))), + true, + ), + ]; + let new_schema = Arc::new(Schema::new_with_metadata(fields, HashMap::new())); + let builder = ParquetRecordBatchReaderBuilder::try_new(file)?.with_batch_size(8192); + assert_eq!( + builder.schema().field(10).clone(), + Field::new( + "timestamp_col", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true + ) + ); + let builder = builder.with_reinterpret_schema(new_schema); + let parquet_reader = builder.build()?; + let schema = parquet_reader.schema; + assert_eq!( + schema.field(10).clone(), + Field::new( + "timestamp_col", + DataType::Timestamp(TimeUnit::Nanosecond, Some(Arc::from("+07:00"))), + true, + ) + ); + Ok(()) + } }