diff --git a/Cargo.toml b/Cargo.toml
index 57c343611..125cd0d06 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -40,6 +40,7 @@ apache-avro = "0.16"
 array-init = "2"
 arrow-arith = { version = "51" }
 arrow-array = { version = "51" }
+arrow-ord = { version = "51" }
 arrow-schema = { version = "51" }
 arrow-select = { version = "51" }
 async-stream = "0.3.5"
diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml
index 46f167b7a..95e007846 100644
--- a/crates/iceberg/Cargo.toml
+++ b/crates/iceberg/Cargo.toml
@@ -34,6 +34,7 @@ apache-avro = { workspace = true }
 array-init = { workspace = true }
 arrow-arith = { workspace = true }
 arrow-array = { workspace = true }
+arrow-ord = { workspace = true }
 arrow-schema = { workspace = true }
 arrow-select = { workspace = true }
 async-stream = { workspace = true }
diff --git a/crates/iceberg/src/arrow/reader.rs b/crates/iceberg/src/arrow/reader.rs
index fe5efaca1..391239cf6 100644
--- a/crates/iceberg/src/arrow/reader.rs
+++ b/crates/iceberg/src/arrow/reader.rs
@@ -17,25 +17,33 @@
 
 //! Parquet file data reader
 
-use arrow_schema::SchemaRef as ArrowSchemaRef;
+use crate::error::Result;
+use arrow_arith::boolean::{and, is_not_null, is_null, not, or};
+use arrow_array::{ArrayRef, BooleanArray, RecordBatch};
+use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
+use arrow_schema::{ArrowError, DataType, SchemaRef as ArrowSchemaRef};
 use async_stream::try_stream;
 use bytes::Bytes;
+use fnv::FnvHashSet;
 use futures::future::BoxFuture;
 use futures::stream::StreamExt;
 use futures::{try_join, TryFutureExt};
+use parquet::arrow::arrow_reader::{ArrowPredicateFn, RowFilter};
 use parquet::arrow::async_reader::{AsyncFileReader, MetadataLoader};
 use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask, PARQUET_FIELD_ID_META_KEY};
 use parquet::file::metadata::ParquetMetaData;
-use parquet::schema::types::SchemaDescriptor;
-use std::collections::HashMap;
+use parquet::schema::types::{SchemaDescriptor, Type as ParquetType};
+use std::collections::{HashMap, HashSet};
 use std::ops::Range;
 use std::str::FromStr;
 use std::sync::Arc;
 
-use crate::arrow::arrow_schema_to_schema;
+use crate::arrow::{arrow_schema_to_schema, get_arrow_datum};
+use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor};
+use crate::expr::{BoundPredicate, BoundReference};
 use crate::io::{FileIO, FileMetadata, FileRead};
 use crate::scan::{ArrowRecordBatchStream, FileScanTaskStream};
-use crate::spec::SchemaRef;
+use crate::spec::{Datum, SchemaRef};
 use crate::{Error, ErrorKind};
 
 /// Builder to create ArrowReader
@@ -44,6 +52,7 @@ pub struct ArrowReaderBuilder {
     field_ids: Vec<usize>,
     file_io: FileIO,
     schema: SchemaRef,
+    predicates: Option<BoundPredicate>,
 }
 
 impl ArrowReaderBuilder {
@@ -54,6 +63,7 @@ impl ArrowReaderBuilder {
             field_ids: vec![],
             file_io,
             schema,
+            predicates: None,
         }
     }
 
@@ -70,6 +80,12 @@ impl ArrowReaderBuilder {
         self
     }
 
+    /// Sets the predicates to apply to the scan.
+    pub fn with_predicates(mut self, predicates: BoundPredicate) -> Self {
+        self.predicates = Some(predicates);
+        self
+    }
+
     /// Build the ArrowReader.
     pub fn build(self) -> ArrowReader {
         ArrowReader {
@@ -77,6 +93,7 @@ impl ArrowReaderBuilder {
             field_ids: self.field_ids,
             schema: self.schema,
             file_io: self.file_io,
+            predicates: self.predicates,
         }
     }
 }
@@ -88,6 +105,7 @@ pub struct ArrowReader {
     #[allow(dead_code)]
     schema: SchemaRef,
     file_io: FileIO,
+    predicates: Option<BoundPredicate>,
 }
 
 impl ArrowReader {
@@ -96,6 +114,14 @@ impl ArrowReader {
     pub fn read(self, mut tasks: FileScanTaskStream) -> crate::Result<ArrowRecordBatchStream> {
         let file_io = self.file_io.clone();
 
+        // Collect the field ids referenced by the predicates.
+        let mut collector = CollectFieldIdVisitor {
+            field_ids: HashSet::default(),
+        };
+        if let Some(predicates) = &self.predicates {
+            visit(&mut collector, predicates)?;
+        }
+
         Ok(try_stream! {
             while let Some(Ok(task)) = tasks.next().await {
                 let parquet_file = file_io
@@ -111,6 +137,13 @@ impl ArrowReader {
                 let projection_mask = self.get_arrow_projection_mask(parquet_schema, arrow_schema)?;
                 batch_stream_builder = batch_stream_builder.with_projection(projection_mask);
 
+                let parquet_schema = batch_stream_builder.parquet_schema();
+                let row_filter = self.get_row_filter(parquet_schema, &collector)?;
+
+                if let Some(row_filter) = row_filter {
+                    batch_stream_builder = batch_stream_builder.with_row_filter(row_filter);
+                }
+
                 if let Some(batch_size) = self.batch_size {
                     batch_stream_builder = batch_stream_builder.with_batch_size(batch_size);
                 }
@@ -193,6 +226,558 @@ impl ArrowReader {
             Ok(ProjectionMask::leaves(parquet_schema, indices))
         }
     }
+
+    fn get_row_filter(
+        &self,
+        parquet_schema: &SchemaDescriptor,
+        collector: &CollectFieldIdVisitor,
+    ) -> Result<Option<RowFilter>> {
+        if let Some(predicates) = &self.predicates {
+            let field_id_map = build_field_id_map(parquet_schema)?;
+
+            // Collect Parquet column indices from field ids.
+            // If a field id is not found in the Parquet schema, it is ignored due to schema evolution.
+            let mut column_indices = collector
+                .field_ids
+                .iter()
+                .filter_map(|field_id| field_id_map.get(field_id).cloned())
+                .collect::<Vec<_>>();
+
+            column_indices.sort();
+
+            // The converter that converts `BoundPredicate`s to `ArrowPredicate`s
+            let mut converter = PredicateConverter {
+                parquet_schema,
+                column_map: &field_id_map,
+                column_indices: &column_indices,
+            };
+
+            // After collecting the required leaf column indices used in the predicates,
+            // create the projection mask for the Arrow predicates.
+            let projection_mask = ProjectionMask::leaves(parquet_schema, column_indices.clone());
+            let predicate_func = visit(&mut converter, predicates)?;
+            let arrow_predicate = ArrowPredicateFn::new(projection_mask, predicate_func);
+            Ok(Some(RowFilter::new(vec![Box::new(arrow_predicate)])))
+        } else {
+            Ok(None)
+        }
+    }
+}
+
+/// Build the map of field id to Parquet column index in the schema.
+fn build_field_id_map(parquet_schema: &SchemaDescriptor) -> Result<HashMap<i32, usize>> {
+    let mut column_map = HashMap::new();
+    for (idx, field) in parquet_schema.columns().iter().enumerate() {
+        let field_type = field.self_type();
+        match field_type {
+            ParquetType::PrimitiveType { basic_info, .. } => {
+                if !basic_info.has_id() {
+                    return Err(Error::new(
+                        ErrorKind::DataInvalid,
+                        format!(
+                            "Leaf column idx: {}, name: {}, type {:?} in schema doesn't have a field id",
+                            idx,
+                            basic_info.name(),
+                            field_type
+                        ),
+                    ));
+                }
+                column_map.insert(basic_info.id(), idx);
+            }
+            ParquetType::GroupType { .. } => {
+                return Err(Error::new(
+                    ErrorKind::DataInvalid,
+                    format!(
+                        "Leaf column in schema should be primitive type but got {:?}",
+                        field_type
+                    ),
+                ));
+            }
+        };
+    }
+
+    Ok(column_map)
+}
+
+/// A visitor to collect field ids from bound predicates.
+struct CollectFieldIdVisitor {
+    field_ids: HashSet<i32>,
+}
+
+impl BoundPredicateVisitor for CollectFieldIdVisitor {
+    type T = ();
+
+    fn always_true(&mut self) -> Result<()> {
+        Ok(())
+    }
+
+    fn always_false(&mut self) -> Result<()> {
+        Ok(())
+    }
+
+    fn and(&mut self, _lhs: (), _rhs: ()) -> Result<()> {
+        Ok(())
+    }
+
+    fn or(&mut self, _lhs: (), _rhs: ()) -> Result<()> {
+        Ok(())
+    }
+
+    fn not(&mut self, _inner: ()) -> Result<()> {
+        Ok(())
+    }
+
+    fn is_null(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_null(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn is_nan(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_nan(&mut self, reference: &BoundReference, _predicate: &BoundPredicate) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn less_than(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn less_than_or_eq(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn greater_than(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn greater_than_or_eq(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn eq(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_eq(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn starts_with(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_starts_with(
+        &mut self,
+        reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn r#in(
+        &mut self,
+        reference: &BoundReference,
+        _literals: &FnvHashSet<Datum>,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+
+    fn not_in(
+        &mut self,
+        reference: &BoundReference,
+        _literals: &FnvHashSet<Datum>,
+        _predicate: &BoundPredicate,
+    ) -> Result<()> {
+        self.field_ids.insert(reference.field().id);
+        Ok(())
+    }
+}
+
+/// A visitor to convert Iceberg bound predicates to Arrow predicates.
+struct PredicateConverter<'a> {
+    /// The Parquet schema descriptor.
+    pub parquet_schema: &'a SchemaDescriptor,
+    /// The map between field id and leaf column index in the Parquet schema.
+    pub column_map: &'a HashMap<i32, usize>,
+    /// The required column indices in the Parquet schema for the predicates.
+    pub column_indices: &'a Vec<usize>,
+}
+
+impl PredicateConverter<'_> {
+    /// When visiting a bound reference, we return the index of the leaf column within the
+    /// required column indices, which is used to project the column in the record batch.
+    /// Returns `None` if the field id is not found in the column map, which is possible
+    /// due to schema evolution.
+    fn bound_reference(&mut self, reference: &BoundReference) -> Result<Option<usize>> {
+        // The leaf column's index in the Parquet schema.
+        if let Some(column_idx) = self.column_map.get(&reference.field().id) {
+            if self.parquet_schema.get_column_root_idx(*column_idx) != *column_idx {
+                return Err(Error::new(
+                    ErrorKind::DataInvalid,
+                    format!(
+                        "Leaf column `{}` in predicates isn't a root column in the Parquet schema.",
+                        reference.field().name
+                    ),
+                ));
+            }
+
+            // The leaf column's index in the required column indices.
+            let index = self
+                .column_indices
+                .iter()
+                .position(|&idx| idx == *column_idx)
+                .ok_or(Error::new(
+                    ErrorKind::DataInvalid,
+                    format!(
+                        "Leaf column `{}` in predicates cannot be found in the required column indices.",
+                        reference.field().name
+                    ),
+                ))?;
+
+            Ok(Some(index))
+        } else {
+            Ok(None)
+        }
+    }
+
+    /// Build an Arrow predicate that always returns true.
+    fn build_always_true(&self) -> Result<Box<PredicateResult>> {
+        Ok(Box::new(|batch| {
+            Ok(BooleanArray::from(vec![true; batch.num_rows()]))
+        }))
+    }
+
+    /// Build an Arrow predicate that always returns false.
+    fn build_always_false(&self) -> Result<Box<PredicateResult>> {
+        Ok(Box::new(|batch| {
+            Ok(BooleanArray::from(vec![false; batch.num_rows()]))
+        }))
+    }
+}
+
+/// Gets the leaf column from the record batch for the required column index. Only
+/// supports top-level columns for now.
+fn project_column(
+    batch: &RecordBatch,
+    column_idx: usize,
+) -> std::result::Result<ArrayRef, ArrowError> {
+    let column = batch.column(column_idx);
+
+    match column.data_type() {
+        DataType::Struct(_) => Err(ArrowError::SchemaError(
+            "Does not support struct column yet.".to_string(),
+        )),
+        _ => Ok(column.clone()),
+    }
+}
+
+type PredicateResult =
+    dyn FnMut(RecordBatch) -> std::result::Result<BooleanArray, ArrowError> + Send + 'static;
+
+impl<'a> BoundPredicateVisitor for PredicateConverter<'a> {
+    type T = Box<PredicateResult>;
+
+    fn always_true(&mut self) -> Result<Box<PredicateResult>> {
+        self.build_always_true()
+    }
+
+    fn always_false(&mut self) -> Result<Box<PredicateResult>> {
+        self.build_always_false()
+    }
+
+    fn and(
+        &mut self,
+        mut lhs: Box<PredicateResult>,
+        mut rhs: Box<PredicateResult>,
+    ) -> Result<Box<PredicateResult>> {
+        Ok(Box::new(move |batch| {
+            let left = lhs(batch.clone())?;
+            let right = rhs(batch)?;
+            and(&left, &right)
+        }))
+    }
+
+    fn or(
+        &mut self,
+        mut lhs: Box<PredicateResult>,
+        mut rhs: Box<PredicateResult>,
+    ) -> Result<Box<PredicateResult>> {
+        Ok(Box::new(move |batch| {
+            let left = lhs(batch.clone())?;
+            let right = rhs(batch)?;
+            or(&left, &right)
+        }))
+    }
+
+    fn not(&mut self, mut inner: Box<PredicateResult>) -> Result<Box<PredicateResult>> {
+        Ok(Box::new(move |batch| {
+            let pred_ret = inner(batch)?;
+            not(&pred_ret)
+        }))
+    }
+
+    fn is_null(
+        &mut self,
+        reference: &BoundReference,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            Ok(Box::new(move |batch| {
+                let column = project_column(&batch, idx)?;
+                is_null(&column)
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_true()
+        }
+    }
+
+    fn not_null(
+        &mut self,
+        reference: &BoundReference,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            Ok(Box::new(move |batch| {
+                let column = project_column(&batch, idx)?;
+                is_not_null(&column)
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn is_nan(
+        &mut self,
+        reference: &BoundReference,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if self.bound_reference(reference)?.is_some() {
+            self.build_always_true()
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn not_nan(
+        &mut self,
+        reference: &BoundReference,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if self.bound_reference(reference)?.is_some() {
+            self.build_always_false()
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_true()
+        }
+    }
+
+    fn less_than(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                lt(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_true()
+        }
+    }
+
+    fn less_than_or_eq(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                lt_eq(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_true()
+        }
+    }
+
+    fn greater_than(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                gt(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn greater_than_or_eq(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                gt_eq(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn eq(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                eq(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn not_eq(
+        &mut self,
+        reference: &BoundReference,
+        literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        if let Some(idx) = self.bound_reference(reference)? {
+            let literal = get_arrow_datum(literal)?;
+
+            Ok(Box::new(move |batch| {
+                let left = project_column(&batch, idx)?;
+                neq(&left, literal.as_ref())
+            }))
+        } else {
+            // A missing column, treating it as null.
+            self.build_always_false()
+        }
+    }
+
+    fn starts_with(
+        &mut self,
+        _reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        // TODO: Implement starts_with
+        self.build_always_true()
+    }
+
+    fn not_starts_with(
+        &mut self,
+        _reference: &BoundReference,
+        _literal: &Datum,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        // TODO: Implement not_starts_with
+        self.build_always_true()
+    }
+
+    fn r#in(
+        &mut self,
+        _reference: &BoundReference,
+        _literals: &FnvHashSet<Datum>,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        // TODO: Implement in
+        self.build_always_true()
+    }
+
+    fn not_in(
+        &mut self,
+        _reference: &BoundReference,
+        _literals: &FnvHashSet<Datum>,
+        _predicate: &BoundPredicate,
+    ) -> Result<Box<PredicateResult>> {
+        // TODO: Implement not_in
+        self.build_always_true()
+    }
 }
 
 /// ArrowFileReader is a wrapper around a FileRead that impls parquets AsyncFileReader.
@@ -234,3 +819,86 @@ impl AsyncFileReader for ArrowFileReader {
         })
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::arrow::reader::CollectFieldIdVisitor;
+    use crate::expr::visitors::bound_predicate_visitor::visit;
+    use crate::expr::{Bind, Reference};
+    use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};
+    use std::collections::HashSet;
+    use std::sync::Arc;
+
+    fn table_schema_simple() -> SchemaRef {
+        Arc::new(
+            Schema::builder()
+                .with_schema_id(1)
+                .with_identifier_field_ids(vec![2])
+                .with_fields(vec![
+                    NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(),
+                    NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
+                    NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(),
+                    NestedField::optional(4, "qux", Type::Primitive(PrimitiveType::Float)).into(),
+                ])
+                .build()
+                .unwrap(),
+        )
+    }
+
+    #[test]
+    fn test_collect_field_id() {
+        let schema = table_schema_simple();
+        let expr = Reference::new("qux").is_null();
+        let bound_expr = expr.bind(schema, true).unwrap();
+
+        let mut visitor = CollectFieldIdVisitor {
+            field_ids: HashSet::default(),
+        };
+        visit(&mut visitor, &bound_expr).unwrap();
+
+        let mut expected = HashSet::default();
+        expected.insert(4_i32);
+
+        assert_eq!(visitor.field_ids, expected);
+    }
+
+    #[test]
+    fn test_collect_field_id_with_and() {
+        let schema = table_schema_simple();
+        let expr = Reference::new("qux")
+            .is_null()
+            .and(Reference::new("baz").is_null());
+        let bound_expr = expr.bind(schema, true).unwrap();
+
+        let mut visitor = CollectFieldIdVisitor {
+            field_ids: HashSet::default(),
+        };
+        visit(&mut visitor, &bound_expr).unwrap();
+
+        let mut expected = HashSet::default();
+        expected.insert(4_i32);
+        expected.insert(3);
+
+        assert_eq!(visitor.field_ids, expected);
+    }
+
+    #[test]
+    fn test_collect_field_id_with_or() {
+        let schema = table_schema_simple();
+        let expr = Reference::new("qux")
+            .is_null()
+            .or(Reference::new("baz").is_null());
+        let bound_expr = expr.bind(schema, true).unwrap();
+
+        let mut visitor = CollectFieldIdVisitor {
+            field_ids: HashSet::default(),
+        };
+        visit(&mut visitor, &bound_expr).unwrap();
+
+        let mut expected = HashSet::default();
+        expected.insert(4_i32);
+        expected.insert(3);
+
+        assert_eq!(visitor.field_ids, expected);
+    }
+}
diff --git a/crates/iceberg/src/arrow/schema.rs b/crates/iceberg/src/arrow/schema.rs
index c7e870096..172d4bb79 100644
--- a/crates/iceberg/src/arrow/schema.rs
+++ b/crates/iceberg/src/arrow/schema.rs
@@ -19,12 +19,16 @@
 
 use crate::error::Result;
 use crate::spec::{
-    ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, SchemaVisitor,
-    StructType, Type,
+    Datum, ListType, MapType, NestedField, NestedFieldRef, PrimitiveLiteral, PrimitiveType, Schema,
+    SchemaVisitor, StructType, Type,
 };
 use crate::{Error, ErrorKind};
 use arrow_array::types::{validate_decimal_precision_and_scale, Decimal128Type};
+use arrow_array::{
+    BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array,
+};
 use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit};
+use bitvec::macros::internal::funty::Fundamental;
 use parquet::arrow::PARQUET_FIELD_ID_META_KEY;
 use rust_decimal::prelude::ToPrimitive;
 use std::collections::HashMap;
@@ -593,6 +597,24 @@ pub fn schema_to_arrow_schema(schema: &crate::spec::Schema) -> crate::Result<ArrowSchema> {
+/// Convert an Iceberg `Datum` to an Arrow `Datum`.
+pub(crate) fn get_arrow_datum(datum: &Datum) -> Result<Box<dyn ArrowDatum + Send>> {
+    match datum.literal() {
+        PrimitiveLiteral::Boolean(value) => Ok(Box::new(BooleanArray::new_scalar(*value))),
+        PrimitiveLiteral::Int(value) => Ok(Box::new(Int32Array::new_scalar(*value))),
+        PrimitiveLiteral::Long(value) => Ok(Box::new(Int64Array::new_scalar(*value))),
+        PrimitiveLiteral::Float(value) => Ok(Box::new(Float32Array::new_scalar(value.as_f32()))),
+        PrimitiveLiteral::Double(value) => Ok(Box::new(Float64Array::new_scalar(value.as_f64()))),
+        l => Err(Error::new(
+            ErrorKind::FeatureUnsupported,
+            format!(
+                "Converting datum from type {:?} to arrow not supported yet.",
+                l
+            ),
+        )),
+    }
+}
+
 impl TryFrom<&ArrowSchema> for crate::spec::Schema {
     type Error = Error;
 
diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs
index 1457d5a17..158ab135b 100644
--- a/crates/iceberg/src/expr/predicate.rs
+++ b/crates/iceberg/src/expr/predicate.rs
@@ -116,10 +116,13 @@ impl<T> UnaryExpression<T> {
         debug_assert!(op.is_unary());
         Self { op, term }
     }
+
+    /// Return the operator of this predicate.
     pub(crate) fn op(&self) -> PredicateOperator {
         self.op
     }
+
+    /// Return the term of this predicate.
     pub(crate) fn term(&self) -> &T {
         &self.term
     }
@@ -155,10 +158,13 @@ impl<T> BinaryExpression<T> {
     pub(crate) fn op(&self) -> PredicateOperator {
         self.op
     }
+
+    /// Return the literal of this predicate.
     pub(crate) fn literal(&self) -> &Datum {
         &self.literal
     }
+
+    /// Return the term of this predicate.
     pub(crate) fn term(&self) -> &T {
         &self.term
     }
@@ -210,13 +216,16 @@ impl<T> SetExpression<T> {
         Self { op, term, literals }
     }
 
+    /// Return the operator of this predicate.
     pub(crate) fn op(&self) -> PredicateOperator {
         self.op
     }
 
     pub(crate) fn literals(&self) -> &FnvHashSet<Datum> {
         &self.literals
     }
 
+    /// Return the term of this predicate.
     pub(crate) fn term(&self) -> &T {
         &self.term
     }
diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs
index c2a5e1b2d..906bf96a4 100644
--- a/crates/iceberg/src/scan.rs
+++ b/crates/iceberg/src/scan.rs
@@ -46,6 +46,7 @@ pub struct TableScanBuilder<'a> {
     table: &'a Table,
     // Empty column names means to select all columns
     column_names: Vec<String>,
+    predicates: Option<Predicate>,
     snapshot_id: Option<i64>,
     batch_size: Option<usize>,
     case_sensitive: bool,
@@ -57,6 +58,7 @@ impl<'a> TableScanBuilder<'a> {
         Self {
             table,
             column_names: vec![],
+            predicates: None,
             snapshot_id: None,
             batch_size: None,
             case_sensitive: true,
@@ -91,6 +93,12 @@ impl<'a> TableScanBuilder<'a> {
         self
     }
 
+    /// Add a predicate to the scan. The scan will only return rows that match the predicate.
+    pub fn filter(mut self, predicate: Predicate) -> Self {
+        self.predicates = Some(predicate);
+        self
+    }
+
     /// Select some columns of the table.
     pub fn select(mut self, column_names: impl IntoIterator<Item = impl ToString>) -> Self {
         self.column_names = column_names
@@ -150,11 +158,18 @@ impl<'a> TableScanBuilder<'a> {
             }
         }
 
+        let bound_predicates = if let Some(ref predicates) = self.predicates {
+            Some(predicates.bind(schema.clone(), true)?)
+        } else {
+            None
+        };
+
         Ok(TableScan {
             snapshot,
             file_io: self.table.file_io().clone(),
             table_metadata: self.table.metadata_ref(),
             column_names: self.column_names,
+            bound_predicates,
             schema,
             batch_size: self.batch_size,
             case_sensitive: self.case_sensitive,
@@ -171,6 +186,7 @@ pub struct TableScan {
     table_metadata: TableMetadataRef,
     file_io: FileIO,
     column_names: Vec<String>,
+    bound_predicates: Option<BoundPredicate>,
     schema: SchemaRef,
     batch_size: Option<usize>,
     case_sensitive: bool,
@@ -300,6 +316,10 @@ impl TableScan {
             arrow_reader_builder = arrow_reader_builder.with_batch_size(batch_size);
         }
 
+        if let Some(ref bound_predicates) = self.bound_predicates {
+            arrow_reader_builder = arrow_reader_builder.with_predicates(bound_predicates.clone());
+        }
+
         arrow_reader_builder.build().read(self.plan_files().await?)
     }
 
@@ -449,9 +469,10 @@ impl FileScanTask {
 #[cfg(test)]
 mod tests {
+    use crate::expr::Reference;
     use crate::io::{FileIO, OutputFile};
     use crate::spec::{
-        DataContentType, DataFileBuilder, DataFileFormat, FormatVersion, Literal, Manifest,
+        DataContentType, DataFileBuilder, DataFileFormat, Datum, FormatVersion, Literal, Manifest,
         ManifestContentType, ManifestEntry, ManifestListWriter, ManifestMetadata, ManifestStatus,
         ManifestWriter, Struct, TableMetadata, EMPTY_SNAPSHOT_ID,
     };
@@ -642,9 +663,23 @@ mod tests {
             ];
             Arc::new(arrow_schema::Schema::new(fields))
         };
+        // 3 columns:
+        // x: [1, 1, 1, 1, ...]
         let col1 = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef;
-        let col2 = Arc::new(Int64Array::from_iter_values(vec![2; 1024])) as ArrayRef;
-        let col3 = Arc::new(Int64Array::from_iter_values(vec![3; 1024])) as ArrayRef;
+
+        let mut values = vec![2; 512];
+        values.append(vec![3; 200].as_mut());
+        values.append(vec![4; 300].as_mut());
+        values.append(vec![5; 12].as_mut());
+
+        // y: [2, 2, 2, 2, ..., 3, 3, 3, 3, ..., 4, 4, 4, 4, ..., 5, 5, 5, 5]
+        let col2 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
+
+        let mut values = vec![3; 512];
+        values.append(vec![4; 512].as_mut());
+
+        // z: [3, 3, 3, 3, ..., 4, 4, 4, 4]
+        let col3 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
 
         let to_write = RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap();
 
         // Write the Parquet files
@@ -803,4 +838,161 @@ mod tests {
         let int64_arr = col2.as_any().downcast_ref::<Int64Array>().unwrap();
         assert_eq!(int64_arr.value(0), 3);
     }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_lt() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y < 3
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").less_than(Datum::long(3));
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches[0].num_rows(), 512);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 2);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_gt_eq() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y >= 5
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").greater_than_or_equal_to(Datum::long(5));
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+
+        assert_eq!(batches[0].num_rows(), 12);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 1);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let int64_arr = col.as_any().downcast_ref::<Int64Array>().unwrap();
+        assert_eq!(int64_arr.value(0), 5);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_is_null() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y is null
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").is_null();
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+        assert_eq!(batches.len(), 0);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_is_not_null() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y is not null
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y").is_not_null();
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+        assert_eq!(batches[0].num_rows(), 1024);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_lt_and_gt() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y < 5 AND z >= 4
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y")
+            .less_than(Datum::long(5))
+            .and(Reference::new("z").greater_than_or_equal_to(Datum::long(4)));
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+        assert_eq!(batches[0].num_rows(), 500);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let expected_x = Arc::new(Int64Array::from_iter_values(vec![1; 500])) as ArrayRef;
+        assert_eq!(col, &expected_x);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let mut values = vec![];
+        values.append(vec![3; 200].as_mut());
+        values.append(vec![4; 300].as_mut());
+        let expected_y = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
+        assert_eq!(col, &expected_y);
+
+        let col = batches[0].column_by_name("z").unwrap();
+        let expected_z = Arc::new(Int64Array::from_iter_values(vec![4; 500])) as ArrayRef;
+        assert_eq!(col, &expected_z);
+    }
+
+    #[tokio::test]
+    async fn test_filter_on_arrow_lt_or_gt() {
+        let mut fixture = TableTestFixture::new();
+        fixture.setup_manifest_files().await;
+
+        // Filter: y < 5 OR z >= 4
+        let mut builder = fixture.table.scan();
+        let predicate = Reference::new("y")
+            .less_than(Datum::long(5))
+            .or(Reference::new("z").greater_than_or_equal_to(Datum::long(4)));
+        builder = builder.filter(predicate);
+        let table_scan = builder.build().unwrap();
+
+        let batch_stream = table_scan.to_arrow().await.unwrap();
+
+        let batches: Vec<_> = batch_stream.try_collect().await.unwrap();
+        assert_eq!(batches[0].num_rows(), 1024);
+
+        let col = batches[0].column_by_name("x").unwrap();
+        let expected_x = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef;
+        assert_eq!(col, &expected_x);
+
+        let col = batches[0].column_by_name("y").unwrap();
+        let mut values = vec![2; 512];
+        values.append(vec![3; 200].as_mut());
+        values.append(vec![4; 300].as_mut());
+        values.append(vec![5; 12].as_mut());
+        let expected_y = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
+        assert_eq!(col, &expected_y);
+
+        let col = batches[0].column_by_name("z").unwrap();
+        let mut values = vec![3; 512];
+        values.append(vec![4; 512].as_mut());
+        let expected_z = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef;
+        assert_eq!(col, &expected_z);
+    }
 }
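
For reviewers, a minimal usage sketch of the pushdown added in this patch. The builder methods (`scan()`, `filter()`, `select()`, `build()`, `to_arrow()`) and the expression API (`Reference`, `Datum`) are taken from the diff above; the `scan_with_filter` helper and the already-loaded `Table` are assumptions for illustration:

use futures::TryStreamExt;
use iceberg::expr::Reference;
use iceberg::spec::Datum;
use iceberg::table::Table;

// Hypothetical helper: scan `table` with `y < 5 AND z >= 4` pushed down to the
// Parquet reader. The bound predicate is converted into an Arrow `RowFilter`,
// so non-matching rows are dropped while the Parquet pages are decoded.
async fn scan_with_filter(table: &Table) -> iceberg::Result<()> {
    let table_scan = table
        .scan()
        .filter(
            Reference::new("y")
                .less_than(Datum::long(5))
                .and(Reference::new("z").greater_than_or_equal_to(Datum::long(4))),
        )
        .select(["x", "y", "z"])
        .build()?;

    // Each emitted RecordBatch contains only rows matching the predicate.
    let mut batch_stream = table_scan.to_arrow().await?;
    while let Some(batch) = batch_stream.try_next().await? {
        println!("matched {} rows", batch.num_rows());
    }
    Ok(())
}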