From 5606bbef5665f5036d7cf892eaa135445ea16538 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 23 Mar 2024 21:33:45 -0700 Subject: [PATCH] feat: Convert predicate to arrow filter and push down to parquet reader --- Cargo.toml | 1 + crates/iceberg/Cargo.toml | 1 + crates/iceberg/src/arrow.rs | 431 ++++++++++++++++++++++++++- crates/iceberg/src/expr/predicate.rs | 45 ++- crates/iceberg/src/scan.rs | 150 +++++++++- crates/iceberg/src/spec/values.rs | 7 +- 6 files changed, 616 insertions(+), 19 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7da16e00d..3552f8a69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ apache-avro = "0.16" array-init = "2" arrow-arith = { version = "51" } arrow-array = { version = "51" } +arrow-ord = { version = "51" } arrow-schema = { version = "51" } async-stream = "0.3.5" async-trait = "0.1" diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml index 5aea856fe..31139552d 100644 --- a/crates/iceberg/Cargo.toml +++ b/crates/iceberg/Cargo.toml @@ -34,6 +34,7 @@ apache-avro = { workspace = true } array-init = { workspace = true } arrow-arith = { workspace = true } arrow-array = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } async-stream = { workspace = true } async-trait = { workspace = true } diff --git a/crates/iceberg/src/arrow.rs b/crates/iceberg/src/arrow.rs index 527fb1917..810dc8892 100644 --- a/crates/iceberg/src/arrow.rs +++ b/crates/iceberg/src/arrow.rs @@ -20,17 +20,30 @@ use async_stream::try_stream; use futures::stream::StreamExt; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +use std::collections::HashMap; use crate::io::FileIO; use crate::scan::{ArrowRecordBatchStream, FileScanTask, FileScanTaskStream}; -use crate::spec::SchemaRef; +use crate::spec::{Datum, PrimitiveLiteral, SchemaRef}; use crate::error::Result; +use crate::expr::{ + BinaryExpression, BoundPredicate, BoundReference, PredicateOperator, SetExpression, + 
UnaryExpression, +}; use crate::spec::{ ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, StructType, Type, }; use crate::{Error, ErrorKind}; +use arrow_arith::boolean::{and, is_not_null, is_null, not, or}; +use arrow_array::{ + BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array, +}; +use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit}; +use bitvec::macros::internal::funty::Fundamental; +use parquet::arrow::arrow_reader::{ArrowPredicate, ArrowPredicateFn, RowFilter}; +use parquet::schema::types::{SchemaDescriptor, Type as ParquetType}; use std::sync::Arc; /// Builder to create ArrowReader @@ -38,6 +51,7 @@ pub struct ArrowReaderBuilder { batch_size: Option, file_io: FileIO, schema: SchemaRef, + predicates: Option>, } impl ArrowReaderBuilder { @@ -47,6 +61,7 @@ impl ArrowReaderBuilder { batch_size: None, file_io, schema, + predicates: None, } } @@ -57,12 +72,19 @@ impl ArrowReaderBuilder { self } + /// Sets the predicates to apply to the scan. + pub fn with_predicates(mut self, predicates: Vec) -> Self { + self.predicates = Some(predicates); + self + } + /// Build the ArrowReader. pub fn build(self) -> ArrowReader { ArrowReader { batch_size: self.batch_size, schema: self.schema, file_io: self.file_io, + predicates: self.predicates, } } } @@ -73,6 +95,7 @@ pub struct ArrowReader { #[allow(dead_code)] schema: SchemaRef, file_io: FileIO, + predicates: Option>, } impl ArrowReader { @@ -95,6 +118,13 @@ impl ArrowReader { .await? 
.with_projection(projection_mask); + let parquet_schema = batch_stream_builder.parquet_schema(); + let row_filter = self.get_row_filter(parquet_schema)?; + + if let Some(row_filter) = row_filter { + batch_stream_builder = batch_stream_builder.with_row_filter(row_filter); + } + if let Some(batch_size) = self.batch_size { batch_stream_builder = batch_stream_builder.with_batch_size(batch_size); } @@ -113,6 +143,405 @@ impl ArrowReader { // TODO: full implementation ProjectionMask::all() } + + fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> Result> { + if let Some(predicates) = &self.predicates { + let field_id_map = self.build_field_id_map(parquet_schema)?; + + // Collect Parquet column indices from field ids + let column_indices = predicates + .iter() + .map(|predicate| { + let mut collector = CollectFieldIdVisitor { field_ids: vec![] }; + collector.visit_predicate(predicate).unwrap(); + collector + .field_ids + .iter() + .map(|field_id| { + field_id_map.get(field_id).cloned().ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "Field id not found in schema") + }) + }) + .collect::>>() + }) + .collect::>>()?; + + // Convert BoundPredicates to ArrowPredicates + let mut arrow_predicates = vec![]; + for (predicate, columns) in predicates.iter().zip(column_indices.iter()) { + let mut converter = PredicateConverter { + columns: columns, + projection_mask: ProjectionMask::leaves(parquet_schema, columns.clone()), + parquet_schema, + column_map: &field_id_map, + }; + let arrow_predicate = converter.visit_predicate(predicate)?; + arrow_predicates.push(arrow_predicate); + } + Ok(Some(RowFilter::new(arrow_predicates))) + } else { + Ok(None) + } + } + + /// Build the map of field id to Parquet column index in the schema. 
+    fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> {
+        let mut column_map = HashMap::new();
+        for (idx, field) in parquet_schema.columns().iter().enumerate() {
+            let field_type = field.self_type();
+            match field_type {
+                ParquetType::PrimitiveType { basic_info, .. } => {
+                    if !basic_info.has_id() {
+                        return Err(Error::new(
+                            ErrorKind::DataInvalid,
+                            format!(
+                                "Leaf column {:?} in schema doesn't have field id",
+                                field_type
+                            ),
+                        ));
+                    }
+                    column_map.insert(basic_info.id(), idx);
+                }
+                ParquetType::GroupType { .. } => {
+                    return Err(Error::new(
+                        ErrorKind::DataInvalid,
+                        format!(
+                            "Leaf column in schema should be primitive type but got {:?}",
+                            field_type
+                        ),
+                    ));
+                }
+            };
+        }
+
+        Ok(column_map)
+    }
+}
+
+/// A visitor to collect field ids from bound predicates.
+struct CollectFieldIdVisitor {
+    field_ids: Vec<i32>,
+}
+
+impl BoundPredicateVisitor for CollectFieldIdVisitor {
+    type T = ();
+    type U = ();
+
+    fn and(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn or(&mut self, _predicates: Vec<Self::T>) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn not(&mut self, _predicate: Self::T) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn visit_always_true(&mut self) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn visit_always_false(&mut self) -> Result<Self::T> {
+        Ok(())
+    }
+
+    fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> {
+        self.bound_reference(predicate.term())?;
+        Ok(())
+    }
+
+    fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> {
+        self.bound_reference(predicate.term())?;
+        Ok(())
+    }
+
+    fn visit_set(&mut self, predicate: &SetExpression<BoundReference>) -> Result<Self::T> {
+        self.bound_reference(predicate.term())?;
+        Ok(())
+    }
+
+    fn bound_reference(&mut self, reference: &BoundReference) -> Result<Self::U> {
+        self.field_ids.push(reference.field().id);
+        Ok(())
+    }
+}
+
+struct PredicateConverter<'a> {
+    pub columns: &'a Vec<usize>,
+    pub projection_mask: ProjectionMask,
+    pub parquet_schema: &'a SchemaDescriptor,
+    pub column_map: &'a HashMap<i32, usize>,
+}
+
+fn get_arrow_datum(datum: &Datum) ->
Box<dyn ArrowDatum> {
+    match datum.literal() {
+        PrimitiveLiteral::Boolean(value) => Box::new(BooleanArray::new_scalar(*value)),
+        PrimitiveLiteral::Int(value) => Box::new(Int32Array::new_scalar(*value)),
+        PrimitiveLiteral::Long(value) => Box::new(Int64Array::new_scalar(*value)),
+        PrimitiveLiteral::Float(value) => Box::new(Float32Array::new_scalar(value.as_f32())),
+        PrimitiveLiteral::Double(value) => Box::new(Float64Array::new_scalar(value.as_f64())),
+        _ => todo!("Unsupported literal type"),
+    }
+}
+
+impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
+    type T = Box<dyn ArrowPredicate>;
+    type U = usize;
+
+    fn visit_always_true(&mut self) -> Result<Self::T> {
+        Ok(Box::new(ArrowPredicateFn::new(
+            self.projection_mask.clone(),
+            |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])),
+        )))
+    }
+
+    fn visit_always_false(&mut self) -> Result<Self::T> {
+        Ok(Box::new(ArrowPredicateFn::new(
+            self.projection_mask.clone(),
+            |batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])),
+        )))
+    }
+
+    fn visit_unary(&mut self, predicate: &UnaryExpression<BoundReference>) -> Result<Self::T> {
+        let term_index = self.bound_reference(predicate.term())?;
+
+        match predicate.op() {
+            PredicateOperator::IsNull => Ok(Box::new(ArrowPredicateFn::new(
+                self.projection_mask.clone(),
+                move |batch| {
+                    let column = batch.column(term_index);
+                    is_null(column)
+                },
+            ))),
+            PredicateOperator::NotNull => Ok(Box::new(ArrowPredicateFn::new(
+                self.projection_mask.clone(),
+                move |batch| {
+                    let column = batch.column(term_index);
+                    is_not_null(column)
+                },
+            ))),
+            PredicateOperator::IsNan => {
+                todo!("IsNan is not supported yet")
+            }
+            PredicateOperator::NotNan => {
+                todo!("NotNan is not supported yet")
+            }
+            op => Err(Error::new(
+                ErrorKind::DataInvalid,
+                format!("Unsupported unary operator: {op}"),
+            )),
+        }
+    }
+
+    fn visit_binary(&mut self, predicate: &BinaryExpression<BoundReference>) -> Result<Self::T> {
+        let term_index = self.bound_reference(predicate.term())?;
+        let literal = predicate.literal().clone();
+
+        match predicate.op() {
+
PredicateOperator::LessThan => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + lt(left, literal.as_ref()) + }, + ))), + PredicateOperator::LessThanOrEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + lt_eq(left, literal.as_ref()) + }, + ))), + PredicateOperator::GreaterThan => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + gt(left, literal.as_ref()) + }, + ))), + PredicateOperator::GreaterThanOrEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + gt_eq(left, literal.as_ref()) + }, + ))), + PredicateOperator::Eq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + eq(left, literal.as_ref()) + }, + ))), + PredicateOperator::NotEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let literal = get_arrow_datum(&literal); + neq(left, literal.as_ref()) + }, + ))), + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported binary operator: {op}"), + )), + } + } + + fn visit_set(&mut self, predicate: &SetExpression) -> Result { + match predicate.op() { + PredicateOperator::In => { + todo!("In is not supported yet") + } + PredicateOperator::NotIn => { + todo!("NotIn is not supported yet") + } + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported set operator: {op}"), + )), + } + } + + fn and(&mut self, mut predicates: Vec) -> Result { + Ok(Box::new(ArrowPredicateFn::new( + 
self.projection_mask.clone(), + move |batch| { + let left = predicates.get_mut(0).unwrap().evaluate(batch.clone())?; + let right = predicates.get_mut(1).unwrap().evaluate(batch)?; + and(&left, &right) + }, + ))) + } + + fn or(&mut self, mut predicates: Vec) -> Result { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = predicates.get_mut(0).unwrap().evaluate(batch.clone())?; + let right = predicates.get_mut(1).unwrap().evaluate(batch)?; + or(&left, &right) + }, + ))) + } + + fn not(&mut self, mut predicate: Self::T) -> Result { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let evaluated = predicate.evaluate(batch.clone())?; + not(&evaluated) + }, + ))) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result { + let column_idx = self.column_map.get(&reference.field().id).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Field id {} not found in schema", reference.field().id), + ) + })?; + + let root_col_index = self.parquet_schema.get_column_root_idx(*column_idx); + + // Find the column index in projection mask. + let column_idx = self + .columns + .iter() + .position(|&x| x == root_col_index) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Column index {} not found in schema", root_col_index), + ) + })?; + + Ok(column_idx) + } +} + +/// A visitor for bound predicates. +pub trait BoundPredicateVisitor { + /// Return type of this visitor on bound predicate. + type T; + + /// Return type of this visitor on bound reference. + type U; + + /// Visit a bound predicate. 
+ fn visit_predicate(&mut self, predicate: &BoundPredicate) -> Result { + match predicate { + BoundPredicate::And(predicates) => self.visit_and(predicates.inputs()), + BoundPredicate::Or(predicates) => self.visit_or(predicates.inputs()), + BoundPredicate::Not(predicate) => self.visit_not(predicate.inputs()), + BoundPredicate::AlwaysTrue => self.visit_always_true(), + BoundPredicate::AlwaysFalse => self.visit_always_false(), + BoundPredicate::Unary(unary) => self.visit_unary(unary), + BoundPredicate::Binary(binary) => self.visit_binary(binary), + BoundPredicate::Set(set) => self.visit_set(set), + } + } + + /// Visit an AND predicate. + fn visit_and(&mut self, predicates: [&BoundPredicate; 2]) -> Result { + let mut results = Vec::with_capacity(predicates.len()); + for predicate in predicates { + let result = self.visit_predicate(predicate)?; + results.push(result); + } + self.and(results) + } + + /// Visit an OR predicate. + fn visit_or(&mut self, predicates: [&BoundPredicate; 2]) -> Result { + let mut results = Vec::with_capacity(predicates.len()); + for predicate in predicates { + let result = self.visit_predicate(predicate)?; + results.push(result); + } + self.or(results) + } + + /// Visit a NOT predicate. + fn visit_not(&mut self, predicate: [&BoundPredicate; 1]) -> Result { + let result = self.visit_predicate(predicate.first().unwrap())?; + self.not(result) + } + + /// Visit an always true predicate. + fn visit_always_true(&mut self) -> Result; + + /// Visit an always false predicate. + fn visit_always_false(&mut self) -> Result; + + /// Visit a unary predicate. + fn visit_unary(&mut self, predicate: &UnaryExpression) -> Result; + + /// Visit a binary predicate. + fn visit_binary(&mut self, predicate: &BinaryExpression) -> Result; + + /// Visit a set predicate. + fn visit_set(&mut self, predicate: &SetExpression) -> Result; + + /// Called after visiting predicates of AND. 
+ fn and(&mut self, predicates: Vec) -> Result; + + /// Called after visiting predicates of OR. + fn or(&mut self, predicates: Vec) -> Result; + + /// Called after visiting predicates of NOT. + fn not(&mut self, predicate: Self::T) -> Result; + + /// Visit a bound reference. + fn bound_reference(&mut self, reference: &BoundReference) -> Result; } /// A post order arrow schema visitor. diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 67a46e2b1..6e33c9a49 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -32,7 +32,7 @@ use crate::spec::{Datum, SchemaRef}; use crate::{Error, ErrorKind}; /// Logical expression, such as `AND`, `OR`, `NOT`. -#[derive(PartialEq)] +#[derive(PartialEq, Clone)] pub struct LogicalExpression { inputs: [Box; N], } @@ -79,7 +79,7 @@ where } /// Unary predicate, for example, `a IS NULL`. -#[derive(PartialEq)] +#[derive(PartialEq, Clone)] pub struct UnaryExpression { /// Operator of this predicate, must be single operand operator. op: PredicateOperator, @@ -116,10 +116,20 @@ impl UnaryExpression { debug_assert!(op.is_unary()); Self { op, term } } + + /// Return the term of this predicate. + pub fn term(&self) -> &T { + &self.term + } + + /// Return the operator of this predicate. + pub fn op(&self) -> &PredicateOperator { + &self.op + } } /// Binary predicate, for example, `a > 10`. -#[derive(PartialEq)] +#[derive(PartialEq, Clone)] pub struct BinaryExpression { /// Operator of this predicate, must be binary operator, such as `=`, `>`, `<`, etc. op: PredicateOperator, @@ -144,6 +154,21 @@ impl BinaryExpression { debug_assert!(op.is_binary()); Self { op, term, literal } } + + /// Return the term of this predicate. + pub fn term(&self) -> &T { + &self.term + } + + /// Return the operator of this predicate. + pub fn op(&self) -> &PredicateOperator { + &self.op + } + + /// Return the literal of this predicate. 
+ pub fn literal(&self) -> &Datum { + &self.literal + } } impl Display for BinaryExpression { @@ -162,7 +187,7 @@ impl Bind for BinaryExpression { } /// Set predicates, for example, `a in (1, 2, 3)`. -#[derive(PartialEq)] +#[derive(PartialEq, Clone)] pub struct SetExpression { /// Operator of this predicate, must be set operator, such as `IN`, `NOT IN`, etc. op: PredicateOperator, @@ -187,6 +212,16 @@ impl SetExpression { debug_assert!(op.is_set()); Self { op, term, literals } } + + /// Return the term of this predicate. + pub fn term(&self) -> &T { + &self.term + } + + /// Return the operator of this predicate. + pub fn op(&self) -> &PredicateOperator { + &self.op + } } impl Bind for SetExpression { @@ -548,7 +583,7 @@ impl Not for Predicate { } /// Bound predicate expression after binding to a schema. -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum BoundPredicate { /// An expression always evaluates to true. AlwaysTrue, diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs index 489161b46..5d7014664 100644 --- a/crates/iceberg/src/scan.rs +++ b/crates/iceberg/src/scan.rs @@ -18,6 +18,7 @@ //! Table scan api. use crate::arrow::ArrowReaderBuilder; +use crate::expr::{Bind, BoundPredicate, Predicate}; use crate::io::FileIO; use crate::spec::{DataContentType, ManifestEntryRef, SchemaRef, SnapshotRef, TableMetadataRef}; use crate::table::Table; @@ -32,6 +33,7 @@ pub struct TableScanBuilder<'a> { table: &'a Table, // Empty column names means to select all columns column_names: Vec, + predicates: Vec, snapshot_id: Option, batch_size: Option, } @@ -41,6 +43,7 @@ impl<'a> TableScanBuilder<'a> { Self { table, column_names: vec![], + predicates: vec![], snapshot_id: None, batch_size: None, } @@ -59,6 +62,12 @@ impl<'a> TableScanBuilder<'a> { self } + /// Add a predicate to the scan. The scan will only return rows that match the predicate. 
+ pub fn filter(mut self, predicate: Predicate) -> Self { + self.predicates.push(predicate); + self + } + /// Select some columns of the table. pub fn select(mut self, column_names: impl IntoIterator) -> Self { self.column_names = column_names @@ -115,11 +124,17 @@ impl<'a> TableScanBuilder<'a> { } } + let mut bound_predicates = Vec::new(); + for predicate in self.predicates { + bound_predicates.push(predicate.bind(schema.clone(), true)?); + } + Ok(TableScan { snapshot, file_io: self.table.file_io().clone(), table_metadata: self.table.metadata_ref(), column_names: self.column_names, + bound_predicates, schema, batch_size: self.batch_size, }) @@ -134,6 +149,7 @@ pub struct TableScan { table_metadata: TableMetadataRef, file_io: FileIO, column_names: Vec, + bound_predicates: Vec, schema: SchemaRef, batch_size: Option, } @@ -191,6 +207,8 @@ impl TableScan { arrow_reader_builder = arrow_reader_builder.with_batch_size(batch_size); } + arrow_reader_builder = arrow_reader_builder.with_predicates(self.bound_predicates.clone()); + arrow_reader_builder.build().read(self.plan_files().await?) 
} } @@ -216,9 +234,10 @@ impl FileScanTask { #[cfg(test)] mod tests { + use crate::expr::Reference; use crate::io::{FileIO, OutputFile}; use crate::spec::{ - DataContentType, DataFileBuilder, DataFileFormat, FormatVersion, Literal, Manifest, + DataContentType, DataFileBuilder, DataFileFormat, Datum, FormatVersion, Literal, Manifest, ManifestContentType, ManifestEntry, ManifestListWriter, ManifestMetadata, ManifestStatus, ManifestWriter, Struct, TableMetadata, EMPTY_SNAPSHOT_ID, }; @@ -390,18 +409,39 @@ mod tests { // prepare data let schema = { - let fields = - vec![ - arrow_schema::Field::new("col", arrow_schema::DataType::Int64, true) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "0".to_string(), - )])), - ]; + let fields = vec![ + arrow_schema::Field::new("x", arrow_schema::DataType::Int64, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "1".to_string(), + )])), + arrow_schema::Field::new("y", arrow_schema::DataType::Int64, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "2".to_string(), + )])), + arrow_schema::Field::new("z", arrow_schema::DataType::Int64, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "3".to_string(), + )])), + ]; Arc::new(arrow_schema::Schema::new(fields)) }; - let col = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef; - let to_write = RecordBatch::try_new(schema.clone(), vec![col]).unwrap(); + let col1 = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef; + + let mut values = vec![2; 512]; + values.append(vec![3; 200].as_mut()); + values.append(vec![4; 300].as_mut()); + values.append(vec![5; 12].as_mut()); + + let col2 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef; + + let mut values = vec![3; 512]; + values.append(vec![4; 512].as_mut()); + + let col3 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef; + let to_write = 
RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap(); // Write the Parquet files let props = WriterProperties::builder() @@ -531,9 +571,95 @@ mod tests { let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); - let col = batches[0].column_by_name("col").unwrap(); + let col = batches[0].column_by_name("x").unwrap(); + + let int64_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 1); + } + + #[tokio::test] + async fn test_filter_on_arrow_lt() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y < 3 + let mut builder = fixture.table.scan(); + let predicate = Reference::new("y").less_than(Datum::long(3)); + builder = builder.filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("x").unwrap(); + let int64_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 1); + + let col = batches[0].column_by_name("y").unwrap(); + let int64_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 2); + } + + #[tokio::test] + async fn test_filter_on_arrow_gt_eq() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y >= 5 + let mut builder = fixture.table.scan(); + let predicate = Reference::new("y").greater_than_or_equal_to(Datum::long(5)); + builder = builder.filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 12); + let col = batches[0].column_by_name("x").unwrap(); let int64_arr = col.as_any().downcast_ref::().unwrap(); assert_eq!(int64_arr.value(0), 1); + + let col = 
batches[0].column_by_name("y").unwrap(); + let int64_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_arr.value(0), 5); + } + + #[tokio::test] + async fn test_filter_on_arrow_is_null() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y is null + let mut builder = fixture.table.scan(); + let predicate = Reference::new("y").is_null(); + builder = builder.filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + assert_eq!(batches.len(), 0); + } + + #[tokio::test] + async fn test_filter_on_arrow_is_not_null() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: y is not null + let mut builder = fixture.table.scan(); + let predicate = Reference::new("y").is_not_null(); + builder = builder.filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + assert_eq!(batches[0].num_rows(), 1024); } } diff --git a/crates/iceberg/src/spec/values.rs b/crates/iceberg/src/spec/values.rs index 00f2e57d2..5de1da5e7 100644 --- a/crates/iceberg/src/spec/values.rs +++ b/crates/iceberg/src/spec/values.rs @@ -84,7 +84,7 @@ pub enum PrimitiveLiteral { /// /// By default, we decouple the type and value of a literal, so we can use avoid the cost of storing extra type info /// for each literal. But associate type with literal can be useful in some cases, for example, in unbound expression. -#[derive(Debug, PartialEq, Hash, Eq)] +#[derive(Debug, Clone, PartialEq, Hash, Eq)] pub struct Datum { r#type: PrimitiveType, literal: PrimitiveLiteral, @@ -673,6 +673,11 @@ impl Datum { )), } } + + /// Returns the literal of the datum. 
+ pub fn literal(&self) -> &PrimitiveLiteral { + &self.literal + } } /// Values present in iceberg type