From 0392dd97f04b97ef23dcffd48df304bc77b710ec Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Sun, 19 Nov 2023 16:04:27 +1100 Subject: [PATCH] Refactor schema/type handling (#45) * Refactor schema/type handling * Removed unused dependency --- Cargo.toml | 1 - src/arrow_reader.rs | 91 ++++---- src/arrow_reader/column.rs | 166 +++++++------ src/async_arrow_reader.rs | 18 +- src/lib.rs | 1 + src/reader.rs | 7 +- src/reader/metadata.rs | 13 +- src/reader/schema.rs | 328 -------------------------- src/schema.rs | 462 +++++++++++++++++++++++++++++++++++++ 9 files changed, 613 insertions(+), 474 deletions(-) delete mode 100644 src/reader/schema.rs create mode 100644 src/schema.rs diff --git a/Cargo.toml b/Cargo.toml index 77693110..93d08d95 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,6 @@ fallible-streaming-iterator = { version = "0.1" } flate2 = "1" futures = { version = "0.3", default-features = false, features = ["std"] } futures-util = "0.3" -lazy_static = "1.4" lz4_flex = "0.11" lzokay-native = "0.1" paste = "1.0" diff --git a/src/arrow_reader.rs b/src/arrow_reader.rs index 801b7098..34388221 100644 --- a/src/arrow_reader.rs +++ b/src/arrow_reader.rs @@ -40,8 +40,8 @@ use crate::error::{self, InvalidColumnSnafu, Result}; use crate::proto::stream::Kind; use crate::proto::StripeFooter; use crate::reader::decompress::{Compression, Decompressor}; -use crate::reader::schema::{create_field, TypeDescription}; use crate::reader::Reader; +use crate::schema::{DataType, RootDataType}; use crate::stripe::StripeMetadata; pub struct ArrowReader { @@ -102,14 +102,7 @@ pub fn create_arrow_schema(cursor: &Cursor) -> Schema { .iter() .map(|(key, value)| (key.clone(), String::from_utf8_lossy(value).to_string())) .collect::>(); - - let fields = cursor - .columns - .iter() - .map(|(name, typ)| Arc::new(create_field((name, typ)))) - .collect::>(); - - Schema::new_with_metadata(fields, metadata) + 
cursor.root_data_type.create_arrow_schema(&metadata) } impl RecordBatchReader for ArrowReader { @@ -726,28 +719,26 @@ pub trait BatchDecoder: Send { } pub fn reader_factory(col: &Column, stripe: &Stripe) -> Result { - let reader = match col.kind() { - crate::proto::r#type::Kind::Boolean => Decoder::Boolean(new_boolean_iter(col, stripe)?), - crate::proto::r#type::Kind::Byte => Decoder::Int8(new_i8_iter(col, stripe)?), - crate::proto::r#type::Kind::Short => Decoder::Int16(new_i64_iter(col, stripe)?), - crate::proto::r#type::Kind::Int => Decoder::Int32(new_i64_iter(col, stripe)?), - crate::proto::r#type::Kind::Long => Decoder::Int64(new_i64_iter(col, stripe)?), - crate::proto::r#type::Kind::Float => Decoder::Float32(new_f32_iter(col, stripe)?), - crate::proto::r#type::Kind::Double => Decoder::Float64(new_f64_iter(col, stripe)?), - crate::proto::r#type::Kind::String => Decoder::String(StringDecoder::new(col, stripe)?), - crate::proto::r#type::Kind::Binary => Decoder::Binary(new_binary_iterator(col, stripe)?), - crate::proto::r#type::Kind::Timestamp => { - Decoder::Timestamp(new_timestamp_iter(col, stripe)?) - } - crate::proto::r#type::Kind::List => Decoder::List(new_list_iter(col, stripe)?), - crate::proto::r#type::Kind::Map => Decoder::Map(new_map_iter(col, stripe)?), - crate::proto::r#type::Kind::Struct => Decoder::Struct(new_struct_iter(col, stripe)?), - crate::proto::r#type::Kind::Union => todo!(), - crate::proto::r#type::Kind::Decimal => todo!(), - crate::proto::r#type::Kind::Date => Decoder::Date(new_i64_iter(col, stripe)?), - crate::proto::r#type::Kind::Varchar => Decoder::String(StringDecoder::new(col, stripe)?), - crate::proto::r#type::Kind::Char => Decoder::String(StringDecoder::new(col, stripe)?), - crate::proto::r#type::Kind::TimestampInstant => todo!(), + let reader = match col.data_type() { + DataType::Boolean { .. } => Decoder::Boolean(new_boolean_iter(col, stripe)?), + DataType::Byte { .. 
} => Decoder::Int8(new_i8_iter(col, stripe)?), + DataType::Short { .. } => Decoder::Int16(new_i64_iter(col, stripe)?), + DataType::Int { .. } => Decoder::Int32(new_i64_iter(col, stripe)?), + DataType::Long { .. } => Decoder::Int64(new_i64_iter(col, stripe)?), + DataType::Float { .. } => Decoder::Float32(new_f32_iter(col, stripe)?), + DataType::Double { .. } => Decoder::Float64(new_f64_iter(col, stripe)?), + DataType::String { .. } => Decoder::String(StringDecoder::new(col, stripe)?), + DataType::Binary { .. } => Decoder::Binary(new_binary_iterator(col, stripe)?), + DataType::Timestamp { .. } => Decoder::Timestamp(new_timestamp_iter(col, stripe)?), + DataType::List { .. } => Decoder::List(new_list_iter(col, stripe)?), + DataType::Map { .. } => Decoder::Map(new_map_iter(col, stripe)?), + DataType::Struct { .. } => Decoder::Struct(new_struct_iter(col, stripe)?), + DataType::Union { .. } => todo!(), + DataType::Decimal { .. } => todo!(), + DataType::Date { .. } => Decoder::Date(new_i64_iter(col, stripe)?), + DataType::Varchar { .. } => Decoder::String(StringDecoder::new(col, stripe)?), + DataType::Char { .. } => Decoder::String(StringDecoder::new(col, stripe)?), + DataType::TimestampWithLocalTimezone { .. 
} => todo!(), }; Ok(reader) @@ -816,35 +807,25 @@ impl NaiveStripeDecoder { pub struct Cursor { pub(crate) reader: Reader, - pub(crate) columns: Arc)>>, + pub(crate) root_data_type: RootDataType, pub(crate) stripe_offset: usize, } impl Cursor { pub fn new>(r: Reader, fields: &[T]) -> Result { - let mut columns = Vec::with_capacity(fields.len()); - for name in fields { - let field = r - .metadata() - .type_description() - .field(name.as_ref()) - .context(error::FieldNotFoundSnafu { - name: name.as_ref(), - })?; - columns.push((name.as_ref().to_string(), field)); - } + let projected_data_type = r.metadata().root_data_type().project(fields); Ok(Self { reader: r, - columns: Arc::new(columns), + root_data_type: projected_data_type, stripe_offset: 0, }) } pub fn root(r: Reader) -> Result { - let columns = r.metadata().type_description().children(); + let data_type = r.metadata().root_data_type().clone(); Ok(Self { reader: r, - columns: Arc::new(columns), + root_data_type: data_type, stripe_offset: 0, }) } @@ -855,7 +836,12 @@ impl Iterator for Cursor { fn next(&mut self) -> Option { if let Some(info) = self.reader.stripe(self.stripe_offset).cloned() { - let stripe = Stripe::new(&mut self.reader, &self.columns, self.stripe_offset, &info); + let stripe = Stripe::new( + &mut self.reader, + self.root_data_type.clone(), + self.stripe_offset, + &info, + ); self.stripe_offset += 1; @@ -901,7 +887,7 @@ impl StreamMap { impl Stripe { pub fn new( r: &mut Reader, - column_defs: &[(String, Arc)], + root_data_type: RootDataType, stripe: usize, info: &StripeMetadata, ) -> Result { @@ -909,10 +895,11 @@ impl Stripe { let compression = r.metadata().compression(); //TODO(weny): add tz - let mut columns = Vec::with_capacity(column_defs.len()); - for (name, typ) in column_defs.iter() { - columns.push(Column::new(name, typ, &footer, info.number_of_rows())); - } + let columns = root_data_type + .children() + .iter() + .map(|(name, data_type)| Column::new(name, data_type, &footer, 
info.number_of_rows())) + .collect(); let mut stream_map = HashMap::new(); let mut stream_offset = info.offset(); diff --git a/src/arrow_reader/column.rs b/src/arrow_reader/column.rs index 53a242c1..89bed071 100644 --- a/src/arrow_reader/column.rs +++ b/src/arrow_reader/column.rs @@ -3,15 +3,13 @@ use std::sync::Arc; use arrow::datatypes::Field; use bytes::Bytes; -use snafu::{OptionExt, ResultExt}; +use snafu::ResultExt; use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; use crate::error::{self, Result}; -use crate::proto::stream::Kind; -use crate::proto::{ColumnEncoding, StripeFooter, StripeInformation}; - -use crate::reader::schema::{create_field, TypeDescription}; +use crate::proto::{ColumnEncoding, StripeFooter}; use crate::reader::Reader; +use crate::schema::DataType; pub mod binary; pub mod boolean; @@ -30,18 +28,20 @@ pub struct Column { number_of_rows: u64, footer: Arc, name: String, - column: Arc, + data_type: DataType, } impl From for Field { fn from(value: Column) -> Self { - create_field((&value.name, &value.column)) + let dt = value.data_type.to_arrow_data_type(); + Field::new(value.name, dt, true) } } impl From<&Column> for Field { fn from(value: &Column) -> Self { - create_field((&value.name, &value.column)) + let dt = value.data_type.to_arrow_data_type(); + Field::new(value.name.clone(), dt, true) } } @@ -64,76 +64,29 @@ macro_rules! 
impl_read_stream { } impl Column { - pub fn read_stream( - reader: &mut Reader, - start: u64, - length: usize, - ) -> Result { - impl_read_stream!(reader, start, length) - } - - pub async fn read_stream_async( - reader: &mut Reader, - start: u64, - length: usize, - ) -> Result { - impl_read_stream!(reader, start, length.await) - } - - pub fn get_stream_info( - name: &str, - column: &Arc, - footer: &Arc, - stripe: &StripeInformation, - ) -> Result<(u64, usize)> { - let mut start = 0; // the start of the stream - - let column_idx = column.column_id() as u32; - - let start = footer - .streams - .iter() - .map(|stream| { - start += stream.length(); - (start, stream) - }) - .find(|(_, stream)| stream.column() == column_idx && stream.kind() != Kind::RowIndex) - .map(|(start, stream)| start - stream.length()) - .with_context(|| error::InvalidColumnSnafu { name })?; - - let length = footer - .streams - .iter() - .filter(|stream| stream.column() == column_idx && stream.kind() != Kind::RowIndex) - .fold(0, |acc, stream| acc + stream.length()) as usize; - let start = stripe.offset() + start; - - Ok((start, length)) - } - pub fn new( name: &str, - column: &Arc, + data_type: &DataType, footer: &Arc, number_of_rows: u64, ) -> Self { Self { number_of_rows, footer: footer.clone(), - column: column.clone(), + data_type: data_type.clone(), name: name.to_string(), } } pub fn dictionary_size(&self) -> usize { - let column = self.column.column_id(); + let column = self.data_type.column_index(); self.footer.columns[column] .dictionary_size .unwrap_or_default() as usize } pub fn encoding(&self) -> ColumnEncoding { - let column = self.column.column_id(); + let column = self.data_type.column_index(); self.footer.columns[column].clone() } @@ -141,8 +94,8 @@ impl Column { self.number_of_rows as usize } - pub fn kind(&self) -> crate::proto::r#type::Kind { - self.column.kind() + pub fn data_type(&self) -> &DataType { + &self.data_type } pub fn name(&self) -> &str { @@ -150,24 +103,89 @@ impl 
Column { } pub fn column_id(&self) -> u32 { - self.column.column_id() as u32 + self.data_type.column_index() as u32 } pub fn children(&self) -> Vec { - let children = self.column.children(); - - let mut columns = Vec::with_capacity(children.len()); - - for (name, column) in children { - columns.push(Column { - number_of_rows: self.number_of_rows, - footer: self.footer.clone(), - name, - column, - }); + match &self.data_type { + DataType::Boolean { .. } + | DataType::Byte { .. } + | DataType::Short { .. } + | DataType::Int { .. } + | DataType::Long { .. } + | DataType::Float { .. } + | DataType::Double { .. } + | DataType::String { .. } + | DataType::Varchar { .. } + | DataType::Char { .. } + | DataType::Binary { .. } + | DataType::Decimal { .. } + | DataType::Timestamp { .. } + | DataType::TimestampWithLocalTimezone { .. } + | DataType::Date { .. } => vec![], + DataType::Struct { children, .. } => children + .iter() + .map(|(name, data_type)| Column { + number_of_rows: self.number_of_rows, + footer: self.footer.clone(), + name: name.clone(), + data_type: data_type.clone(), + }) + .collect(), + DataType::List { child, .. } => { + vec![Column { + number_of_rows: self.number_of_rows, + footer: self.footer.clone(), + name: "item".to_string(), + data_type: *child.clone(), + }] + } + DataType::Map { key, value, .. } => { + vec![ + Column { + number_of_rows: self.number_of_rows, + footer: self.footer.clone(), + name: "key".to_string(), + data_type: *key.clone(), + }, + Column { + number_of_rows: self.number_of_rows, + footer: self.footer.clone(), + name: "value".to_string(), + data_type: *value.clone(), + }, + ] + } + DataType::Union { variants, .. 
} => { + // TODO: might need corrections + variants + .iter() + .enumerate() + .map(|(index, data_type)| Column { + number_of_rows: self.number_of_rows, + footer: self.footer.clone(), + name: format!("{index}"), + data_type: data_type.clone(), + }) + .collect() + } } + } + + pub fn read_stream( + reader: &mut Reader, + start: u64, + length: usize, + ) -> Result { + impl_read_stream!(reader, start, length) + } - columns + pub async fn read_stream_async( + reader: &mut Reader, + start: u64, + length: usize, + ) -> Result { + impl_read_stream!(reader, start, length.await) } } diff --git a/src/async_arrow_reader.rs b/src/async_arrow_reader.rs index 6dd7082e..37450695 100644 --- a/src/async_arrow_reader.rs +++ b/src/async_arrow_reader.rs @@ -17,8 +17,8 @@ use crate::arrow_reader::{ create_arrow_schema, Cursor, NaiveStripeDecoder, StreamMap, Stripe, DEFAULT_BATCH_SIZE, }; use crate::error::Result; -use crate::reader::schema::TypeDescription; use crate::reader::Reader; +use crate::schema::RootDataType; use crate::stripe::StripeMetadata; pub type BoxedDecoder = Box> + Send>; @@ -70,11 +70,11 @@ impl StripeFactory { pub async fn read_next_stripe_inner(&mut self, info: &StripeMetadata) -> Result { let inner = &mut self.inner; - let column_defs = inner.columns.clone(); + let root_data_type = inner.root_data_type.clone(); let stripe_offset = inner.stripe_offset; inner.stripe_offset += 1; - Stripe::new_async(&mut inner.reader, column_defs, stripe_offset, info).await + Stripe::new_async(&mut inner.reader, root_data_type, stripe_offset, info).await } pub async fn read_next_stripe(mut self) -> Result<(Self, Option)> { @@ -181,9 +181,10 @@ impl Stream for ArrowStreamRe } impl Stripe { + // TODO: reduce duplication with sync version in arrow_reader.rs pub async fn new_async( r: &mut Reader, - column_defs: Arc)>>, + root_data_type: RootDataType, stripe: usize, info: &StripeMetadata, ) -> Result { @@ -191,10 +192,11 @@ impl Stripe { let compression = r.metadata().compression(); 
//TODO(weny): add tz - let mut columns = Vec::with_capacity(column_defs.len()); - for (name, typ) in column_defs.iter() { - columns.push(Column::new(name, typ, &footer, info.number_of_rows())); - } + let columns = root_data_type + .children() + .iter() + .map(|(name, data_type)| Column::new(name, data_type, &footer, info.number_of_rows())) + .collect(); let mut stream_map = HashMap::new(); let mut stream_offset = info.offset(); diff --git a/src/lib.rs b/src/lib.rs index 1ab9f46b..915ae156 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub(crate) mod builder; pub mod error; pub mod proto; pub mod reader; +pub mod schema; pub mod statistics; pub mod stripe; diff --git a/src/reader.rs b/src/reader.rs index e7f6ea2e..23722964 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,7 +1,6 @@ pub mod decode; pub mod decompress; pub mod metadata; -pub mod schema; use std::fs::File; use std::io::{BufReader, Read, Seek, SeekFrom}; @@ -9,10 +8,10 @@ use std::io::{BufReader, Read, Seek, SeekFrom}; use tokio::io::{AsyncRead, AsyncSeek}; use self::metadata::{read_metadata, FileMetadata}; -use self::schema::TypeDescription; use crate::error::Result; use crate::proto::StripeFooter; use crate::reader::metadata::read_metadata_async; +use crate::schema::RootDataType; use crate::stripe::StripeMetadata; pub struct Reader { @@ -32,8 +31,8 @@ impl Reader { &self.metadata } - pub fn schema(&self) -> &TypeDescription { - self.metadata.type_description() + pub fn schema(&self) -> &RootDataType { + self.metadata.root_data_type() } pub fn stripe(&self, index: usize) -> Option<&StripeMetadata> { diff --git a/src/reader/metadata.rs b/src/reader/metadata.rs index 844f938e..f07ad543 100644 --- a/src/reader/metadata.rs +++ b/src/reader/metadata.rs @@ -24,7 +24,6 @@ use std::collections::HashMap; use std::io::{Read, SeekFrom}; -use std::sync::Arc; use bytes::Bytes; use prost::Message; @@ -34,11 +33,11 @@ use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt}; use 
crate::error::{self, Result}; use crate::proto::{self, Footer, Metadata, PostScript, StripeFooter}; use crate::reader::decompress::Decompressor; +use crate::schema::RootDataType; use crate::statistics::ColumnStatistics; use crate::stripe::StripeMetadata; use super::decompress::Compression; -use super::schema::{create_schema, TypeDescription}; use super::ChunkReader; const DEFAULT_FOOTER_SIZE: u64 = 16 * 1024; @@ -47,7 +46,7 @@ const DEFAULT_FOOTER_SIZE: u64 = 16 * 1024; #[derive(Debug)] pub struct FileMetadata { compression: Option, - type_description: Arc, + root_data_type: RootDataType, number_of_rows: u64, /// Statistics of columns across entire file column_statistics: Vec, @@ -67,7 +66,7 @@ impl FileMetadata { ) -> Result { let compression = Compression::from_proto(postscript.compression(), postscript.compression_block_size); - let type_description = create_schema(&footer.types, 0)?; + let root_data_type = RootDataType::from_proto(&footer.types)?; let number_of_rows = footer.number_of_rows(); let column_statistics = footer .statistics @@ -88,7 +87,7 @@ impl FileMetadata { Ok(Self { compression, - type_description, + root_data_type, number_of_rows, column_statistics, stripes, @@ -105,8 +104,8 @@ impl FileMetadata { self.compression } - pub fn type_description(&self) -> &Arc { - &self.type_description + pub fn root_data_type(&self) -> &RootDataType { + &self.root_data_type } pub fn column_file_statistics(&self) -> &[ColumnStatistics] { diff --git a/src/reader/schema.rs b/src/reader/schema.rs deleted file mode 100644 index c5d1b862..00000000 --- a/src/reader/schema.rs +++ /dev/null @@ -1,328 +0,0 @@ -use std::sync::{Arc, Mutex, Weak}; - -use arrow::datatypes::{DataType, Field, Fields, UnionFields, UnionMode}; -use lazy_static::lazy_static; -use snafu::ensure; - -use crate::error::{self, Result}; -use crate::proto::r#type::Kind; -use crate::proto::Type; - -#[derive(Debug, Clone)] -pub struct Category { - name: String, - is_primitive: bool, - kind: Kind, -} - -impl 
Category { - pub fn new(name: &str, is_primitive: bool, kind: Kind) -> Self { - Self { - name: name.to_string(), - is_primitive, - kind, - } - } - - pub fn primitive(&self) -> bool { - self.is_primitive - } - - pub fn name(&self) -> &str { - &self.name - } -} - -pub fn create_field((name, typ): (&str, &Arc)) -> Field { - let kind = typ.kind(); - match kind { - Kind::Boolean => Field::new(name, DataType::Boolean, true), - Kind::Byte => Field::new(name, DataType::Int8, true), - Kind::Short => Field::new(name, DataType::Int16, true), - Kind::Int => Field::new(name, DataType::Int32, true), - Kind::Long => Field::new(name, DataType::Int64, true), - Kind::Float => Field::new(name, DataType::Float32, true), - Kind::Double => Field::new(name, DataType::Float64, true), - Kind::String => Field::new(name, DataType::Utf8, true), - Kind::Binary => Field::new(name, DataType::LargeBinary, true), - // TODO(weny): handle tz - Kind::Timestamp => Field::new( - name, - DataType::Timestamp(arrow::datatypes::TimeUnit::Nanosecond, None), - true, - ), - Kind::List => { - let children = typ.children(); - assert_eq!(children.len(), 1); - - let (_, typ) = &children[0]; - let value = create_field((name, typ)); - - Field::new(name, DataType::List(Arc::new(value)), true) - } - Kind::Map => { - let children = typ.children(); - assert_eq!(children.len(), 2); - - let (_, typ) = &children[0]; - let key = create_field(("key", typ)); - - let (_, typ) = &children[1]; - let value = create_field(("value", typ)); - let fields = vec![key, value]; - - let data_type = DataType::Struct(Fields::from(fields)); - let field = Field::new(name, data_type, true); - Field::new(name, DataType::Map(Arc::new(field), false), true) - } - Kind::Struct => { - let children = typ.children(); - let mut fields = Vec::with_capacity(children.len()); - for (name, child) in &children { - fields.push(create_field((name, child))); - } - - Field::new(name, DataType::Struct(Fields::from(fields)), true) - } - Kind::Union => { - let 
children = typ.children(); - let mut fields = Vec::with_capacity(children.len()); - for (idx, (name, child)) in children.iter().enumerate() { - fields.push((idx as i8, Arc::new(create_field((name, child))))); - } - - Field::new( - name, - DataType::Union(UnionFields::from_iter(fields), UnionMode::Sparse), - true, - ) - } - Kind::Decimal => { - let inner = typ.inner.lock().unwrap(); - Field::new( - name, - DataType::Decimal128(inner.precision as u8, inner.scale as i8), - true, - ) - } - Kind::Date => Field::new(name, DataType::Date32, true), - Kind::Varchar => Field::new(name, DataType::Utf8, true), - Kind::Char => Field::new(name, DataType::Utf8, true), - Kind::TimestampInstant => todo!(), - } -} - -lazy_static! { - static ref BOOLEAN: Category = Category::new("boolean", true, Kind::Boolean); - static ref TINYINT: Category = Category::new("tinyint", true, Kind::Byte); - static ref SMALLINT: Category = Category::new("smallint", true, Kind::Short); - static ref INT: Category = Category::new("int", true, Kind::Int); - static ref BIGINT: Category = Category::new("bigint", true, Kind::Long); - static ref FLOAT: Category = Category::new("float", true, Kind::Float); - static ref DOUBLE: Category = Category::new("double", true, Kind::Double); - static ref STRING: Category = Category::new("string", true, Kind::String); - static ref DATE: Category = Category::new("date", true, Kind::Date); - static ref TIMESTAMP: Category = Category::new("timestamp", true, Kind::Timestamp); - static ref BINARY: Category = Category::new("binary", true, Kind::Binary); - static ref DECIMAL: Category = Category::new("decimal", true, Kind::Decimal); - static ref VARCHAR: Category = Category::new("varchar", true, Kind::Varchar); - static ref CHAR: Category = Category::new("char", true, Kind::Char); - static ref ARRAY: Category = Category::new("array", false, Kind::List); - static ref MAP: Category = Category::new("map", false, Kind::Map); - static ref STRUCT: Category = Category::new("struct", 
false, Kind::Struct); - static ref UNIONTYPE: Category = Category::new("uniontype", false, Kind::Union); -} - -#[derive(Debug)] -pub struct TypeDescription { - inner: Mutex, -} - -impl TypeDescription { - pub fn new(category: Category, column: usize) -> Self { - Self { - inner: Mutex::new(TypeDescriptionInner::new(category, column)), - } - } - - pub fn set_parent(self: &Arc, parent: Weak) { - self.inner.lock().unwrap().set_parent(parent); - } - - pub fn add_field(self: &Arc, name: String, td: Arc) { - let mut inner = self.inner.lock().unwrap(); - inner.add_field(name, td.clone()); - let parent = Arc::downgrade(self); - td.set_parent(parent); - } - - pub fn field(&self, name: &str) -> Option> { - self.inner.lock().unwrap().get_field(name) - } - - pub fn column_id(&self) -> usize { - self.inner.lock().unwrap().column - } - - pub fn kind(&self) -> Kind { - self.inner.lock().unwrap().category.kind - } - - pub fn children(&self) -> Vec<(String, Arc)> { - let inner = self.inner.lock().unwrap(); - - let children = inner.children.clone().unwrap_or_default(); - - let names = inner.field_names.clone(); - - names.into_iter().zip(children).collect() - } -} - -#[derive(Debug)] - -pub struct TypeDescriptionInner { - category: Category, - parent: Option>, - children: Option>>, - field_names: Vec, - precision: usize, - scale: usize, - // column index - column: usize, -} - -const DEFAULT_SCALE: usize = 10; -const DEFAULT_PRECISION: usize = 38; - -impl TypeDescriptionInner { - pub fn new(category: Category, column: usize) -> Self { - Self { - category, - parent: None, - children: None, - field_names: Vec::new(), - precision: DEFAULT_PRECISION, - scale: DEFAULT_SCALE, - column, - } - } - - pub fn set_parent(&mut self, parent: Weak) { - self.parent = Some(parent); - } - - pub fn add_field(&mut self, name: String, td: Arc) { - self.field_names.push(name); - if self.children.is_none() { - self.children = Some(Vec::new()); - } - self.children.as_mut().unwrap().push(td); - } - - pub fn 
get_field(&self, name: &str) -> Option> { - let idx = self.field_names.iter().position(|f| f.eq(name)); - idx.and_then(|idx| self.children.as_ref().unwrap().get(idx).cloned()) - } -} - -pub fn create_schema(types: &[Type], root_column: usize) -> Result> { - if types.is_empty() { - return error::NoTypesSnafu {}.fail(); - } - - let root = &types[root_column]; - - match root.kind() { - Kind::Struct => { - let td = Arc::new(TypeDescription::new(STRUCT.clone(), root_column)); - let sub_types = &root.subtypes; - let fields = &root.field_names; - for (idx, column) in sub_types.iter().enumerate() { - let child = create_schema(types, *column as usize)?; - td.add_field(fields[idx].to_string(), child); - } - Ok(td) - } - - Kind::Boolean => Ok(Arc::new(TypeDescription::new(BOOLEAN.clone(), root_column))), - - // 8,16,32,64 - Kind::Byte => Ok(Arc::new(TypeDescription::new(TINYINT.clone(), root_column))), - Kind::Short => Ok(Arc::new(TypeDescription::new( - SMALLINT.clone(), - root_column, - ))), - Kind::Int => Ok(Arc::new(TypeDescription::new(INT.clone(), root_column))), - Kind::Long => Ok(Arc::new(TypeDescription::new(BIGINT.clone(), root_column))), - - // f32/f64 - Kind::Float => Ok(Arc::new(TypeDescription::new(FLOAT.clone(), root_column))), - Kind::Double => Ok(Arc::new(TypeDescription::new(DOUBLE.clone(), root_column))), - - // String - Kind::String => Ok(Arc::new(TypeDescription::new(STRING.clone(), root_column))), - Kind::Varchar => Ok(Arc::new(TypeDescription::new(VARCHAR.clone(), root_column))), - Kind::Char => Ok(Arc::new(TypeDescription::new(CHAR.clone(), root_column))), - - // Timestamp/Date - Kind::Timestamp => Ok(Arc::new(TypeDescription::new( - TIMESTAMP.clone(), - root_column, - ))), - Kind::Date => Ok(Arc::new(TypeDescription::new(DATE.clone(), root_column))), - - // FIXME(weny): Test propose - Kind::Binary => Ok(Arc::new(TypeDescription::new(BINARY.clone(), root_column))), - Kind::List => { - let sub_types = &root.subtypes; - ensure!( - sub_types.len() == 1, - 
error::UnexpectedSnafu { - msg: format!("unexpected number of subtypes for list: {:?}", sub_types) - } - ); - - let td = Arc::new(TypeDescription::new(ARRAY.clone(), root_column)); - - let column = sub_types[0]; - let child = create_schema(types, column as usize)?; - // TODO(weny): remove dummy name. - td.add_field("root".to_string(), child); - - Ok(td) - } - Kind::Map => { - let sub_types = &root.subtypes; - ensure!( - sub_types.len() == 2, - error::UnexpectedSnafu { - msg: format!("unexpected number of subtypes for map: {:?}", sub_types) - } - ); - - let td = Arc::new(TypeDescription::new(MAP.clone(), root_column)); - let fields = &["key", "value"]; - for (idx, column) in sub_types.iter().enumerate() { - let child = create_schema(types, *column as usize)?; - td.add_field(fields[idx].to_string(), child); - } - - Ok(td) - } - Kind::Union => { - let td = Arc::new(TypeDescription::new(UNIONTYPE.clone(), root_column)); - - let sub_types = &root.subtypes; - let fields = &root.field_names; - for (idx, column) in sub_types.iter().enumerate() { - let child = create_schema(types, *column as usize)?; - td.add_field(fields[idx].to_string(), child); - } - - Ok(td) - } - Kind::Decimal => Ok(Arc::new(TypeDescription::new(DECIMAL.clone(), root_column))), - Kind::TimestampInstant => todo!(), - } -} diff --git a/src/schema.rs b/src/schema.rs new file mode 100644 index 00000000..8d6c96c8 --- /dev/null +++ b/src/schema.rs @@ -0,0 +1,462 @@ +use std::collections::HashMap; +use std::fmt::Display; +use std::sync::Arc; + +use snafu::{ensure, OptionExt}; + +use crate::error::{NoTypesSnafu, Result, UnexpectedSnafu}; +use crate::proto; + +use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, UnionMode}; + +/// Represents the root data type of the ORC file. Contains multiple named child types +/// which map to the columns available. Allows projecting only specific columns from +/// the base schema. 
+/// +/// This is essentially a Struct type, but with special handling such as for projection +/// and transforming into an Arrow schema. +/// +/// Note that the ORC spec states the root type does not necessarily have to be a Struct. +/// Currently we only support having a Struct as the root data type. +/// +/// See: +#[derive(Debug, Clone)] +pub struct RootDataType { + children: Vec<(String, DataType)>, +} + +impl RootDataType { + /// Root column index is always 0. + pub fn column_index(&self) -> usize { + 0 + } + + /// Base columns of the file. + pub fn children(&self) -> &[(String, DataType)] { + &self.children + } + + /// Convert into an Arrow schema. + pub fn create_arrow_schema(&self, user_metadata: &HashMap) -> Schema { + let fields = self + .children + .iter() + .map(|(name, dt)| { + let dt = dt.to_arrow_data_type(); + Field::new(name, dt, true) + }) + .collect::>(); + Schema::new_with_metadata(fields, user_metadata.clone()) + } + + /// Project only specific columns from the root type by column name. + pub fn project>(&self, fields: &[T]) -> Self { + // TODO: change project to accept project mask (vec of bools) instead of relying on col names? + // TODO: be able to nest project? (i.e. project child struct data type) unsure if actually desirable + let fields = fields.iter().map(AsRef::as_ref).collect::>(); + let children = self + .children + .iter() + .filter(|c| fields.contains(&c.0.as_str())) + .map(|c| c.to_owned()) + .collect::>(); + Self { children } + } + + /// Construct from protobuf types. 
+ /// Construct from the flat proto type list; the root type is always
+ /// at index 0 and (per the assertion in the helper below) must be a
+ /// struct of the top-level columns.
+ pub fn from_proto(types: &[proto::Type]) -> Result {
+ ensure!(!types.is_empty(), NoTypesSnafu {});
+ let children = parse_struct_children_from_proto(types, 0)?;
+ Ok(Self { children })
+ }
+}
+
+impl Display for RootDataType {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "ROOT")?;
+ for child in &self.children {
+ write!(f, "\n {} {}", child.0, child.1)?;
+ }
+ Ok(())
+ }
+}
+
+/// Helper function since this is duplicated for [`RootDataType`] and [`DataType::Struct`]
+/// parsing from proto.
+fn parse_struct_children_from_proto(
+ types: &[proto::Type],
+ column_index: usize,
+) -> Result> {
+ // These pre-conditions should always be upheld, especially as this is a private function
+ assert!(column_index < types.len());
+ let ty = &types[column_index];
+ assert!(ty.kind() == proto::r#type::Kind::Struct);
+ ensure!(
+ ty.subtypes.len() == ty.field_names.len(),
+ UnexpectedSnafu {
+ msg: format!(
+ "Struct type for column index {} must have matching lengths for subtypes and field names lists",
+ column_index,
+ )
+ }
+ );
+ // Pair each subtype index with its field name; short-circuits on the
+ // first child that fails to parse.
+ let children = ty
+ .subtypes
+ .iter()
+ .zip(ty.field_names.iter())
+ .map(|(&index, name)| {
+ let index = index as usize;
+ let name = name.to_owned();
+ let dt = DataType::from_proto(types, index)?;
+ Ok((name, dt))
+ })
+ .collect::>>()?;
+ Ok(children)
+}
+
+/// Represents the exact data types supported by ORC.
+///
+/// Each variant holds the column index in order to associate the type
+/// with the specific column data present in the stripes.
+#[derive(Debug, Clone)]
+pub enum DataType {
+ /// 1 bit packed data.
+ Boolean { column_index: usize },
+ /// 8 bit integer, also called TinyInt.
+ Byte { column_index: usize },
+ /// 16 bit integer, also called SmallInt.
+ Short { column_index: usize },
+ /// 32 bit integer.
+ Int { column_index: usize },
+ /// 64 bit integer, also called BigInt.
+ Long { column_index: usize },
+ /// 32 bit floating-point number.
+ Float { column_index: usize },
+ /// 64 bit floating-point number.
+ Double { column_index: usize },
+ /// UTF-8 encoded strings.
+ String { column_index: usize },
+ /// UTF-8 encoded strings, with an upper length limit on values.
+ Varchar {
+ column_index: usize,
+ max_length: u32,
+ },
+ /// UTF-8 encoded strings, with an upper length limit on values.
+ Char {
+ column_index: usize,
+ max_length: u32,
+ },
+ /// Arbitrary byte array values.
+ Binary { column_index: usize },
+ /// Decimal numbers with a fixed precision and scale.
+ Decimal {
+ column_index: usize,
+ // TODO: narrow to u8
+ precision: u32,
+ scale: u32,
+ },
+ /// Represents specific date and time, down to the nanosecond, as offset
+ /// since 1st January 2015, with no timezone.
+ ///
+ /// The date and time represented by values of this column does not change
+ /// based on the reader's timezone.
+ Timestamp { column_index: usize },
+ /// Represents specific date and time, down to the nanosecond, as offset
+ /// since 1st January 2015, with timezone.
+ ///
+ /// The date and time represented by values of this column changes based
+ /// on the reader's timezone (is a fixed instant in time).
+ TimestampWithLocalTimezone { column_index: usize },
+ /// Represents specific date (without time) as days since the UNIX epoch
+ /// (1st January 1970 UTC).
+ Date { column_index: usize },
+ /// Compound type with named child subtypes, representing a structured
+ /// collection of children types.
+ Struct {
+ column_index: usize,
+ children: Vec<(String, DataType)>,
+ },
+ /// Compound type where each value in the column is a list of values
+ /// of another type, specified by the child type.
+ List {
+ column_index: usize,
+ child: Box,
+ },
+ /// Compound type with two children subtypes, key and value, representing
+ /// key-value pairs for column values.
+ Map {
+ column_index: usize,
+ key: Box,
+ value: Box,
+ },
+ /// Compound type which can represent multiple types of data within
+ /// the same column.
+ ///
+ /// Its variants represent which types it can be (where each value in
+ /// the column can only be one of these types).
+ Union {
+ column_index: usize,
+ variants: Vec,
+ },
+}
+
+impl DataType {
+ /// Retrieve the column index of this data type, used for getting the specific column
+ /// streams/statistics in the file.
+ pub fn column_index(&self) -> usize {
+ match self {
+ DataType::Boolean { column_index } => *column_index,
+ DataType::Byte { column_index } => *column_index,
+ DataType::Short { column_index } => *column_index,
+ DataType::Int { column_index } => *column_index,
+ DataType::Long { column_index } => *column_index,
+ DataType::Float { column_index } => *column_index,
+ DataType::Double { column_index } => *column_index,
+ DataType::String { column_index } => *column_index,
+ DataType::Varchar { column_index, .. } => *column_index,
+ DataType::Char { column_index, .. } => *column_index,
+ DataType::Binary { column_index } => *column_index,
+ DataType::Decimal { column_index, .. } => *column_index,
+ DataType::Timestamp { column_index } => *column_index,
+ DataType::TimestampWithLocalTimezone { column_index } => *column_index,
+ DataType::Date { column_index } => *column_index,
+ DataType::Struct { column_index, .. } => *column_index,
+ DataType::List { column_index, .. } => *column_index,
+ DataType::Map { column_index, .. } => *column_index,
+ DataType::Union { column_index, .. } => *column_index,
+ }
+ }
+
+ /// Resolve the type at `column_index` in the flat proto type list into a
+ /// [`DataType`] tree, recursing into subtypes for compound kinds
+ /// (list/map/struct/union).
+ fn from_proto(types: &[proto::Type], column_index: usize) -> Result {
+ let ty = types.get(column_index).context(UnexpectedSnafu {
+ msg: format!("Column index out of bounds: {column_index}"),
+ })?;
+ let dt = match ty.kind() {
+ proto::r#type::Kind::Boolean => Self::Boolean { column_index },
+ proto::r#type::Kind::Byte => Self::Byte { column_index },
+ proto::r#type::Kind::Short => Self::Short { column_index },
+ proto::r#type::Kind::Int => Self::Int { column_index },
+ proto::r#type::Kind::Long => Self::Long { column_index },
+ proto::r#type::Kind::Float => Self::Float { column_index },
+ proto::r#type::Kind::Double => Self::Double { column_index },
+ proto::r#type::Kind::String => Self::String { column_index },
+ proto::r#type::Kind::Binary => Self::Binary { column_index },
+ proto::r#type::Kind::Timestamp => Self::Timestamp { column_index },
+ proto::r#type::Kind::List => {
+ ensure!(
+ ty.subtypes.len() == 1,
+ UnexpectedSnafu {
+ msg: format!(
+ "List type for column index {} must have 1 sub type, found {}",
+ column_index,
+ ty.subtypes.len()
+ )
+ }
+ );
+ let child = ty.subtypes[0] as usize;
+ let child = Box::new(Self::from_proto(types, child)?);
+ Self::List {
+ column_index,
+ child,
+ }
+ }
+ proto::r#type::Kind::Map => {
+ ensure!(
+ ty.subtypes.len() == 2,
+ UnexpectedSnafu {
+ msg: format!(
+ "Map type for column index {} must have 2 sub types, found {}",
+ column_index,
+ ty.subtypes.len()
+ )
+ }
+ );
+ let key = ty.subtypes[0] as usize;
+ let key = Box::new(Self::from_proto(types, key)?);
+ let value = ty.subtypes[1] as usize;
+ let value = Box::new(Self::from_proto(types, value)?);
+ Self::Map {
+ column_index,
+ key,
+ value,
+ }
+ }
+ proto::r#type::Kind::Struct => {
+ let children = parse_struct_children_from_proto(types, column_index)?;
+ Self::Struct {
+ column_index,
+ children,
+ }
+ }
+ proto::r#type::Kind::Union => {
+ // NOTE(review): Arrow sparse unions use i8 type ids, so only
+ // 128 non-negative ids exist; a 256-variant cap permits ids
+ // that wrap negative in to_arrow_data_type — confirm intended cap.
+ ensure!(
+ ty.subtypes.len() <= 256,
+ UnexpectedSnafu {
+ msg: format!(
+ "Union type for column index {} cannot exceed 256 variants, found {}",
+ column_index,
+ ty.subtypes.len()
+ )
+ }
+ );
+ let variants = ty
+ .subtypes
+ .iter()
+ .map(|&index| {
+ let index = index as usize;
+ Self::from_proto(types, index)
+ })
+ .collect::>>()?;
+ Self::Union {
+ column_index,
+ variants,
+ }
+ }
+ proto::r#type::Kind::Decimal => Self::Decimal {
+ column_index,
+ precision: ty.precision(),
+ scale: ty.scale(),
+ },
+ proto::r#type::Kind::Date => Self::Date { column_index },
+ proto::r#type::Kind::Varchar => Self::Varchar {
+ column_index,
+ max_length: ty.maximum_length(),
+ },
+ proto::r#type::Kind::Char => Self::Char {
+ column_index,
+ max_length: ty.maximum_length(),
+ },
+ proto::r#type::Kind::TimestampInstant => {
+ Self::TimestampWithLocalTimezone { column_index }
+ }
+ };
+ Ok(dt)
+ }
+
+ /// Map this ORC type to the closest equivalent Arrow data type
+ /// (Varchar/Char length limits and writer timezone are not carried over).
+ pub fn to_arrow_data_type(&self) -> ArrowDataType {
+ match self {
+ DataType::Boolean { .. } => ArrowDataType::Boolean,
+ DataType::Byte { .. } => ArrowDataType::Int8,
+ DataType::Short { .. } => ArrowDataType::Int16,
+ DataType::Int { .. } => ArrowDataType::Int32,
+ DataType::Long { .. } => ArrowDataType::Int64,
+ DataType::Float { .. } => ArrowDataType::Float32,
+ DataType::Double { .. } => ArrowDataType::Float64,
+ DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => {
+ ArrowDataType::Utf8
+ }
+ DataType::Binary { .. } => ArrowDataType::Binary,
+ DataType::Decimal {
+ precision, scale, ..
+ // NOTE(review): `as u8` / `as i8` narrow silently, so precision > 255
+ // or out-of-range scale would wrap — presumably bounded by the ORC
+ // writer; TODO confirm (see the "narrow to u8" TODO on the variant).
+ } => ArrowDataType::Decimal128(*precision as u8, *scale as i8),
+ DataType::Timestamp { .. } => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
+ DataType::TimestampWithLocalTimezone { .. } => {
+ // TODO: get writer timezone
+ ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
+ }
+ DataType::Date { .. } => ArrowDataType::Date32,
+ DataType::Struct { children, .. } => {
+ let children = children
+ .iter()
+ .map(|(name, dt)| {
+ let dt = dt.to_arrow_data_type();
+ Field::new(name, dt, true)
+ })
+ .collect();
+ ArrowDataType::Struct(children)
+ }
+ DataType::List { child, .. } => {
+ let child = child.to_arrow_data_type();
+ ArrowDataType::new_list(child, true)
+ }
+ DataType::Map { key, value, .. } => {
+ let key = key.to_arrow_data_type();
+ let key = Field::new("key", key, true);
+ let value = value.to_arrow_data_type();
+ let value = Field::new("value", value, true);
+
+ // Arrow maps are encoded as a list of "item" structs with
+ // key/value children.
+ let dt = ArrowDataType::Struct(vec![key, value].into());
+ let dt = Arc::new(Field::new("item", dt, true));
+ ArrowDataType::Map(dt, true)
+ }
+ DataType::Union { variants, .. } => {
+ let fields = variants
+ .iter()
+ .enumerate()
+ .map(|(index, variant)| {
+ // NOTE(review): from_proto caps variants at 256, but this
+ // `as u8 as i8` cast wraps indices 128..=255 to negative
+ // type ids, which Arrow unions do not allow — confirm the
+ // intended cap is 127/128 and tighten the ensure upstream.
+ let index = index as u8 as i8;
+ let arrow_dt = variant.to_arrow_data_type();
+ // Name shouldn't matter here (only ORC struct types give names to subtypes anyway)
+ let field = Arc::new(Field::new(format!("{index}"), arrow_dt, true));
+ (index, field)
+ })
+ .collect();
+ ArrowDataType::Union(fields, UnionMode::Sparse)
+ }
+ }
+ }
+}
+
+impl Display for DataType {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ DataType::Boolean { column_index: _ } => write!(f, "BOOLEAN"),
+ DataType::Byte { column_index: _ } => write!(f, "BYTE"),
+ DataType::Short { column_index: _ } => write!(f, "SHORT"),
+ DataType::Int { column_index: _ } => write!(f, "INTEGER"),
+ DataType::Long { column_index: _ } => write!(f, "LONG"),
+ DataType::Float { column_index: _ } => write!(f, "FLOAT"),
+ DataType::Double { column_index: _ } => write!(f, "DOUBLE"),
+ DataType::String { column_index: _ } => write!(f, "STRING"),
+ DataType::Varchar {
+ column_index: _,
+ max_length,
+ } => write!(f, "VARCHAR({max_length})"),
+ DataType::Char {
+ column_index: _,
+ max_length,
+ } => write!(f, "CHAR({max_length})"),
+ DataType::Binary { column_index: _ } => write!(f, "BINARY"),
+ DataType::Decimal {
+ column_index: _,
+ precision,
+ scale,
+ } => write!(f, "DECIMAL({precision}, {scale})"),
+ DataType::Timestamp { column_index: _ } => write!(f, "TIMESTAMP"),
+ DataType::TimestampWithLocalTimezone { column_index: _ } => {
+ write!(f, "TIMESTAMP INSTANT")
+ }
+ DataType::Date { column_index: _ } => write!(f, "DATE"),
+ DataType::Struct {
+ column_index: _,
+ children,
+ } => {
+ write!(f, "STRUCT")?;
+ for child in children {
+ write!(f, "\n {} {}", child.0, child.1)?;
+ }
+ Ok(())
+ }
+ DataType::List {
+ column_index: _,
+ child,
+ } => write!(f, "LIST\n {child}"),
+ DataType::Map {
+ column_index: _,
+ key,
+ value,
+ } => write!(f, "MAP\n {key}\n {value}"),
+ DataType::Union {
+ column_index: _,
+ variants,
+ } => {
+ write!(f, "UNION")?;
+ for variant in variants {
+ write!(f, "\n {variant}")?;
+ }
+ Ok(())
+ }
+ }
+ }
+}