diff --git a/benches/arrow_reader.rs b/benches/arrow_reader.rs index fc517d61..f4661880 100644 --- a/benches/arrow_reader.rs +++ b/benches/arrow_reader.rs @@ -18,7 +18,7 @@ use std::fs::File; use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion_orc::{ArrowReader, ArrowStreamReader, Cursor}; +use datafusion_orc::arrow_reader::ArrowReaderBuilder; use futures_util::TryStreamExt; fn basic_path(path: &str) -> String { @@ -43,25 +43,19 @@ async fn async_read_all() { let file = "demo-12-zlib.orc"; let file_path = basic_path(file); let f = tokio::fs::File::open(file_path).await.unwrap(); - - let cursor = Cursor::root_async(f).await.unwrap(); - - ArrowStreamReader::new(cursor, None) - .try_collect::>() + let reader = ArrowReaderBuilder::try_new_async(f) .await - .unwrap(); + .unwrap() + .build_async(); + let _ = reader.try_collect::>().await.unwrap(); } fn sync_read_all() { let file = "demo-12-zlib.orc"; let file_path = basic_path(file); let f = File::open(file_path).unwrap(); - - let cursor = Cursor::root(f).unwrap(); - - ArrowReader::new(cursor, None) - .collect::, _>>() - .unwrap(); + let reader = ArrowReaderBuilder::try_new(f).unwrap().build(); + let _ = reader.collect::, _>>().unwrap(); } fn criterion_benchmark(c: &mut Criterion) { diff --git a/src/arrow_reader.rs b/src/arrow_reader.rs index b488ed41..31e5391b 100644 --- a/src/arrow_reader.rs +++ b/src/arrow_reader.rs @@ -46,28 +46,99 @@ use crate::reader::metadata::{read_metadata, read_metadata_async, FileMetadata}; use crate::reader::{AsyncChunkReader, ChunkReader}; use crate::schema::{DataType, RootDataType}; use crate::stripe::StripeMetadata; +use crate::ArrowStreamReader; -pub struct ArrowReader { - cursor: Cursor, - schema_ref: SchemaRef, - current_stripe: Option>>>, +pub const DEFAULT_BATCH_SIZE: usize = 8192; + +pub struct ArrowReaderBuilder { + reader: R, + file_metadata: Arc, batch_size: usize, + projection: ProjectionMask, } -pub const DEFAULT_BATCH_SIZE: usize = 8192; - -impl ArrowReader { - pub fn new(cursor: Cursor, batch_size: Option) -> Self { - let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE); - let schema = Arc::new(create_arrow_schema(&cursor)); +impl ArrowReaderBuilder { + fn new(reader: R, file_metadata: Arc) -> Self { Self { + reader, + file_metadata, + batch_size: DEFAULT_BATCH_SIZE, + projection: ProjectionMask::all(), + } + } + + pub fn file_metadata(&self) -> &FileMetadata { + &self.file_metadata + } + + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + pub fn with_projection(mut self, projection: ProjectionMask) -> Self { + self.projection = projection; + self + } +} + +impl ArrowReaderBuilder { + pub fn try_new(mut reader: R) -> Result { + let file_metadata = Arc::new(read_metadata(&mut reader)?); + Ok(Self::new(reader, file_metadata)) + } + + pub fn build(self) -> ArrowReader { + let projected_data_type = self + .file_metadata + .root_data_type() + .project(&self.projection); + let cursor = Cursor { + reader: self.reader, + file_metadata: self.file_metadata, + projected_data_type, + stripe_offset: 0, + }; + let schema_ref = Arc::new(create_arrow_schema(&cursor)); + ArrowReader { cursor, - schema_ref: schema, + schema_ref, current_stripe: None, - batch_size, + batch_size: self.batch_size, } } +} + +impl ArrowReaderBuilder { + pub async fn try_new_async(mut reader: R) -> Result { + let file_metadata = Arc::new(read_metadata_async(&mut reader).await?); + Ok(Self::new(reader, file_metadata)) + } + + pub fn build_async(self) -> ArrowStreamReader { + let projected_data_type = self + .file_metadata + .root_data_type() + .project(&self.projection); + let cursor = Cursor { + reader: self.reader, + file_metadata: self.file_metadata, + projected_data_type, + stripe_offset: 0, + }; + let schema_ref = Arc::new(create_arrow_schema(&cursor)); + ArrowStreamReader::new(cursor, self.batch_size, schema_ref) + } +} + +pub struct ArrowReader { + cursor: Cursor, + schema_ref: SchemaRef, + current_stripe: Option>>>, + batch_size: usize, +} +impl ArrowReader { pub fn total_row_count(&self) -> u64 { self.cursor.file_metadata.number_of_rows() } @@ -814,56 +885,6 @@ pub struct Cursor { pub(crate) stripe_offset: usize, } -impl Cursor { - pub fn new>(mut reader: R, fields: &[T]) -> Result { - let file_metadata = Arc::new(read_metadata(&mut reader)?); - let mask = ProjectionMask::named_roots(file_metadata.root_data_type(), fields); - let projected_data_type = file_metadata.root_data_type().project(&mask); - Ok(Self { - reader, - file_metadata, - projected_data_type, - stripe_offset: 0, - }) - } - - pub fn root(mut reader: R) -> Result { - let file_metadata = Arc::new(read_metadata(&mut reader)?); - let data_type = file_metadata.root_data_type().clone(); - Ok(Self { - reader, - file_metadata, - projected_data_type: data_type, - stripe_offset: 0, - }) - } -} - -impl Cursor { - pub async fn new_async>(mut reader: R, fields: &[T]) -> Result { - let file_metadata = Arc::new(read_metadata_async(&mut reader).await?); - let mask = ProjectionMask::named_roots(file_metadata.root_data_type(), fields); - let projected_data_type = file_metadata.root_data_type().project(&mask); - Ok(Self { - reader, - file_metadata, - projected_data_type, - stripe_offset: 0, - }) - } - - pub async fn root_async(mut reader: R) -> Result { - let file_metadata = Arc::new(read_metadata_async(&mut reader).await?); - let data_type = file_metadata.root_data_type().clone(); - Ok(Self { - reader, - file_metadata, - projected_data_type: data_type, - stripe_offset: 0, - }) - } -} - impl Iterator for Cursor { type Item = Result; diff --git a/src/async_arrow_reader.rs b/src/async_arrow_reader.rs index db120055..e5ae8287 100644 --- a/src/async_arrow_reader.rs +++ b/src/async_arrow_reader.rs @@ -14,8 +14,7 @@ use snafu::ResultExt; use crate::arrow_reader::column::Column; use crate::arrow_reader::{ - create_arrow_schema, deserialize_stripe_footer, Cursor, NaiveStripeDecoder, StreamMap, Stripe, - DEFAULT_BATCH_SIZE, + deserialize_stripe_footer, Cursor, NaiveStripeDecoder, StreamMap, Stripe, }; use crate::error::{IoSnafu, Result}; use crate::reader::metadata::FileMetadata; @@ -110,13 +109,11 @@ impl StripeFactory { } impl ArrowStreamReader { - pub fn new(c: Cursor, batch_size: Option) -> Self { - let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE); - let schema = Arc::new(create_arrow_schema(&c)); + pub fn new(cursor: Cursor, batch_size: usize, schema_ref: SchemaRef) -> Self { Self { - factory: Some(Box::new(c.into())), + factory: Some(Box::new(cursor.into())), batch_size, - schema_ref: schema, + schema_ref, state: StreamState::Init, } } diff --git a/src/lib.rs b/src/lib.rs index c2f6c1b7..40788a01 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,5 +9,5 @@ pub mod schema; pub mod statistics; pub mod stripe; -pub use arrow_reader::{ArrowReader, Cursor}; +pub use arrow_reader::{ArrowReader, ArrowReaderBuilder}; pub use async_arrow_reader::ArrowStreamReader; diff --git a/tests/basic/main.rs b/tests/basic/main.rs index fec74713..63c7016d 100644 --- a/tests/basic/main.rs +++ b/tests/basic/main.rs @@ -2,8 +2,9 @@ use std::fs::File; use arrow::record_batch::RecordBatch; use arrow::util::pretty; -use datafusion_orc::arrow_reader::{ArrowReader, Cursor}; +use datafusion_orc::arrow_reader::{ArrowReader, ArrowReaderBuilder}; use datafusion_orc::async_arrow_reader::ArrowStreamReader; +use datafusion_orc::projection::ProjectionMask; use futures_util::TryStreamExt; use crate::misc::{LONG_BOOL_EXPECTED, LONG_STRING_DICT_EXPECTED, LONG_STRING_EXPECTED}; @@ -12,26 +13,22 @@ mod misc; fn new_arrow_reader(path: &str, fields: &[&str]) -> ArrowReader { let f = File::open(path).expect("no file found"); - - let cursor = Cursor::new(f, fields).unwrap(); - - ArrowReader::new(cursor, None) + let builder = ArrowReaderBuilder::try_new(f).unwrap(); + let projection = ProjectionMask::named_roots(builder.file_metadata().root_data_type(), fields); + builder.with_projection(projection).build() } async fn new_arrow_stream_reader_root(path: &str) -> ArrowStreamReader { let f = tokio::fs::File::open(path).await.unwrap(); - - let cursor = Cursor::root_async(f).await.unwrap(); - - ArrowStreamReader::new(cursor, None) + ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .build_async() } fn new_arrow_reader_root(path: &str) -> ArrowReader { let f = File::open(path).expect("no file found"); - - let cursor = Cursor::root(f).unwrap(); - - ArrowReader::new(cursor, None) + ArrowReaderBuilder::try_new(f).unwrap().build() } fn basic_path(path: &str) -> String {