From 22af140effb74fd2c70a860eeb703040829e1e23 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 23 Mar 2024 21:33:45 -0700 Subject: [PATCH] feat: Convert predicate to arrow filter and push down to parquet reader --- Cargo.toml | 1 + crates/iceberg/Cargo.toml | 1 + crates/iceberg/src/arrow.rs | 400 +++++++++++++++++++++++++++ crates/iceberg/src/expr/predicate.rs | 35 +++ 4 files changed, 437 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 7da16e00d..3552f8a69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ apache-avro = "0.16" array-init = "2" arrow-arith = { version = "51" } arrow-array = { version = "51" } +arrow-ord = { version = "51" } arrow-schema = { version = "51" } async-stream = "0.3.5" async-trait = "0.1" diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml index 5aea856fe..31139552d 100644 --- a/crates/iceberg/Cargo.toml +++ b/crates/iceberg/Cargo.toml @@ -34,6 +34,7 @@ apache-avro = { workspace = true } array-init = { workspace = true } arrow-arith = { workspace = true } arrow-array = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } async-stream = { workspace = true } async-trait = { workspace = true } diff --git a/crates/iceberg/src/arrow.rs b/crates/iceberg/src/arrow.rs index 527fb1917..5af61c459 100644 --- a/crates/iceberg/src/arrow.rs +++ b/crates/iceberg/src/arrow.rs @@ -20,17 +20,27 @@ use async_stream::try_stream; use futures::stream::StreamExt; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; +use std::collections::HashMap; use crate::io::FileIO; use crate::scan::{ArrowRecordBatchStream, FileScanTask, FileScanTaskStream}; use crate::spec::SchemaRef; use crate::error::Result; +use crate::expr::{ + BinaryExpression, BoundPredicate, BoundReference, PredicateOperator, SetExpression, + UnaryExpression, +}; use crate::spec::{ ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, StructType, Type, }; use crate::{Error, ErrorKind}; +use arrow_arith::boolean::{and, is_not_null, is_null, not, or}; +use arrow_array::BooleanArray; +use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq}; use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit}; +use parquet::arrow::arrow_reader::{ArrowPredicate, ArrowPredicateFn, RowFilter}; +use parquet::schema::types::{SchemaDescriptor, Type as ParquetType}; use std::sync::Arc; /// Builder to create ArrowReader @@ -38,6 +48,7 @@ pub struct ArrowReaderBuilder { batch_size: Option, file_io: FileIO, schema: SchemaRef, + predicates: Option>, } impl ArrowReaderBuilder { @@ -47,6 +58,7 @@ impl ArrowReaderBuilder { batch_size: None, file_io, schema, + predicates: None, } } @@ -57,12 +69,19 @@ impl ArrowReaderBuilder { self } + /// Sets the predicates to apply to the scan. + pub fn with_predicates(mut self, predicates: Vec) -> Self { + self.predicates = Some(predicates); + self + } + /// Build the ArrowReader. pub fn build(self) -> ArrowReader { ArrowReader { batch_size: self.batch_size, schema: self.schema, file_io: self.file_io, + predicates: self.predicates, } } } @@ -73,6 +92,7 @@ pub struct ArrowReader { #[allow(dead_code)] schema: SchemaRef, file_io: FileIO, + predicates: Option>, } impl ArrowReader { @@ -95,6 +115,13 @@ impl ArrowReader { .await? .with_projection(projection_mask); + let parquet_schema = batch_stream_builder.parquet_schema(); + let row_filter = self.get_row_filter(parquet_schema)?; + + if let Some(row_filter) = row_filter { + batch_stream_builder = batch_stream_builder.with_row_filter(row_filter); + } + if let Some(batch_size) = self.batch_size { batch_stream_builder = batch_stream_builder.with_batch_size(batch_size); } @@ -113,6 +140,379 @@ impl ArrowReader { // TODO: full implementation ProjectionMask::all() } + + fn get_row_filter(&self, parquet_schema: &SchemaDescriptor) -> Result> { + if let Some(predicates) = &self.predicates { + let field_id_map = self.build_field_id_map(parquet_schema)?; + + // Collect Parquet column indices from field ids + let column_indices = predicates + .iter() + .map(|predicate| { + let mut collector = CollectFieldIdVisitor { field_ids: vec![] }; + collector.visit_predicate(predicate).unwrap(); + collector + .field_ids + .iter() + .map(|field_id| { + field_id_map.get(field_id).cloned().ok_or_else(|| { + Error::new(ErrorKind::DataInvalid, "Field id not found in schema") + }) + }) + .collect::>>() + }) + .collect::>>()?; + + // Convert BoundPredicates to ArrowPredicates + let mut arrow_predicates = vec![]; + for (predicate, columns) in predicates.iter().zip(column_indices.iter()) { + let mut converter = PredicateConverter { + projection_mask: ProjectionMask::leaves(parquet_schema, columns.clone()), + parquet_schema, + column_map: &field_id_map, + }; + let arrow_predicate = converter.visit_predicate(predicate)?; + arrow_predicates.push(arrow_predicate); + } + Ok(Some(RowFilter::new(arrow_predicates))) + } else { + Ok(None) + } + } + + /// Build the map of field id to Parquet column index in the schema. + fn build_field_id_map(&self, parquet_schema: &SchemaDescriptor) -> Result> { + let mut column_map = HashMap::new(); + for (idx, field) in parquet_schema.columns().iter().enumerate() { + let field_type = field.self_type(); + match field_type { + ParquetType::PrimitiveType { basic_info, .. } => { + if !basic_info.has_id() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column {:?} in schema doesn't have field id", + field_type + ), + )); + } + column_map.insert(basic_info.id(), idx); + } + ParquetType::GroupType { .. } => { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Leave column in schema should be primitive type but got {:?}", + field_type + ), + )); + } + }; + } + + Ok(column_map) + } +} + +/// A visitor to collect field ids from bound predicates. +struct CollectFieldIdVisitor { + field_ids: Vec, +} + +impl BoundPredicateVisitor for CollectFieldIdVisitor { + type T = (); + type U = (); + + fn and(&mut self, _predicates: Vec) -> Result { + Ok(()) + } + + fn or(&mut self, _predicates: Vec) -> Result { + Ok(()) + } + + fn not(&mut self, _predicate: Self::T) -> Result { + Ok(()) + } + + fn visit_always_true(&mut self) -> Result { + Ok(()) + } + + fn visit_always_false(&mut self) -> Result { + Ok(()) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression) -> Result { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_binary(&mut self, predicate: &BinaryExpression) -> Result { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn visit_set(&mut self, predicate: &SetExpression) -> Result { + self.bound_reference(predicate.term())?; + Ok(()) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result { + self.field_ids.push(reference.field().id); + Ok(()) + } +} + +struct PredicateConverter<'a> { + pub projection_mask: ProjectionMask, + pub parquet_schema: &'a SchemaDescriptor, + pub column_map: &'a HashMap, +} + +impl<'a> BoundPredicateVisitor for PredicateConverter<'a> { + type T = Box; + type U = usize; + + fn visit_always_true(&mut self) -> Result { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![true; batch.num_rows()])), + ))) + } + + fn visit_always_false(&mut self) -> Result { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + |batch| Ok(BooleanArray::from(vec![false; batch.num_rows()])), + ))) + } + + fn visit_unary(&mut self, predicate: &UnaryExpression) -> Result { + let term_index = self.bound_reference(predicate.term())?; + + match predicate.op() { + PredicateOperator::IsNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_null(column) + }, + ))), + PredicateOperator::NotNull => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let column = batch.column(term_index); + is_not_null(column) + }, + ))), + PredicateOperator::IsNan => { + todo!("IsNan is not supported yet") + } + PredicateOperator::NotNan => { + todo!("NotNan is not supported yet") + } + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported unary operator: {op}"), + )), + } + } + + fn visit_binary(&mut self, predicate: &BinaryExpression) -> Result { + let term_index = self.bound_reference(predicate.term())?; + + match predicate.op() { + PredicateOperator::LessThan => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let right = batch.column(term_index); + lt(left, right) + }, + ))), + PredicateOperator::LessThanOrEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let right = batch.column(term_index); + lt_eq(left, right) + }, + ))), + PredicateOperator::GreaterThan => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let right = batch.column(term_index); + gt(right, left) + }, + ))), + PredicateOperator::GreaterThanOrEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let right = batch.column(term_index); + gt_eq(right, left) + }, + ))), + PredicateOperator::Eq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let right = batch.column(term_index); + eq(left, right) + }, + ))), + PredicateOperator::NotEq => Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = batch.column(term_index); + let right = batch.column(term_index); + neq(left, right) + }, + ))), + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported binary operator: {op}"), + )), + } + } + + fn visit_set(&mut self, predicate: &SetExpression) -> Result { + match predicate.op() { + PredicateOperator::In => { + todo!("In is not supported yet") + } + PredicateOperator::NotIn => { + todo!("NotIn is not supported yet") + } + op => Err(Error::new( + ErrorKind::DataInvalid, + format!("Unsupported set operator: {op}"), + )), + } + } + + fn and(&mut self, mut predicates: Vec) -> Result { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = predicates.get_mut(0).unwrap().evaluate(batch.clone())?; + let right = predicates.get_mut(1).unwrap().evaluate(batch)?; + and(&left, &right) + }, + ))) + } + + fn or(&mut self, mut predicates: Vec) -> Result { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let left = predicates.get_mut(0).unwrap().evaluate(batch.clone())?; + let right = predicates.get_mut(1).unwrap().evaluate(batch)?; + or(&left, &right) + }, + ))) + } + + fn not(&mut self, mut predicate: Self::T) -> Result { + Ok(Box::new(ArrowPredicateFn::new( + self.projection_mask.clone(), + move |batch| { + let evaluated = predicate.evaluate(batch.clone())?; + not(&evaluated) + }, + ))) + } + + fn bound_reference(&mut self, reference: &BoundReference) -> Result { + let column_idx = self.column_map.get(&reference.field().id).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Field id {} not found in schema", reference.field().id), + ) + })?; + + let root_col_index = self.parquet_schema.get_column_root_idx(*column_idx); + + Ok(root_col_index) + } +} + +/// A visitor for bound predicates. +pub trait BoundPredicateVisitor { + /// Return type of this visitor on bound predicate. + type T; + + /// Return type of this visitor on bound reference. + type U; + + /// Visit a bound predicate. + fn visit_predicate(&mut self, predicate: &BoundPredicate) -> Result { + match predicate { + BoundPredicate::And(predicates) => self.visit_and(predicates.inputs()), + BoundPredicate::Or(predicates) => self.visit_or(predicates.inputs()), + BoundPredicate::Not(predicate) => self.visit_not(predicate.inputs()), + BoundPredicate::AlwaysTrue => self.visit_always_true(), + BoundPredicate::AlwaysFalse => self.visit_always_false(), + BoundPredicate::Unary(unary) => self.visit_unary(unary), + BoundPredicate::Binary(binary) => self.visit_binary(binary), + BoundPredicate::Set(set) => self.visit_set(set), + } + } + + /// Visit an AND predicate. + fn visit_and(&mut self, predicates: [&BoundPredicate; 2]) -> Result { + let mut results = Vec::with_capacity(predicates.len()); + for predicate in predicates { + let result = self.visit_predicate(predicate)?; + results.push(result); + } + self.and(results) + } + + /// Visit an OR predicate. + fn visit_or(&mut self, predicates: [&BoundPredicate; 2]) -> Result { + let mut results = Vec::with_capacity(predicates.len()); + for predicate in predicates { + let result = self.visit_predicate(predicate)?; + results.push(result); + } + self.or(results) + } + + /// Visit a NOT predicate. + fn visit_not(&mut self, predicate: [&BoundPredicate; 1]) -> Result { + let result = self.visit_predicate(predicate.first().unwrap())?; + self.not(result) + } + + /// Visit an always true predicate. + fn visit_always_true(&mut self) -> Result; + + /// Visit an always false predicate. + fn visit_always_false(&mut self) -> Result; + + /// Visit a unary predicate. + fn visit_unary(&mut self, predicate: &UnaryExpression) -> Result; + + /// Visit a binary predicate. + fn visit_binary(&mut self, predicate: &BinaryExpression) -> Result; + + /// Visit a set predicate. + fn visit_set(&mut self, predicate: &SetExpression) -> Result; + + /// Called after visiting predicates of AND. + fn and(&mut self, predicates: Vec) -> Result; + + /// Called after visiting predicates of OR. + fn or(&mut self, predicates: Vec) -> Result; + + /// Called after visiting predicates of NOT. + fn not(&mut self, predicate: Self::T) -> Result; + + /// Visit a bound reference. + fn bound_reference(&mut self, reference: &BoundReference) -> Result; } /// A post order arrow schema visitor. diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 67a46e2b1..2bd29992c 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -116,6 +116,16 @@ impl UnaryExpression { debug_assert!(op.is_unary()); Self { op, term } } + + /// Return the term of this predicate. + pub fn term(&self) -> &T { + &self.term + } + + /// Return the operator of this predicate. + pub fn op(&self) -> &PredicateOperator { + &self.op + } } /// Binary predicate, for example, `a > 10`. @@ -144,6 +154,21 @@ impl BinaryExpression { debug_assert!(op.is_binary()); Self { op, term, literal } } + + /// Return the term of this predicate. + pub fn term(&self) -> &T { + &self.term + } + + /// Return the operator of this predicate. + pub fn op(&self) -> &PredicateOperator { + &self.op + } + + /// Return the literal of this predicate. + pub fn literal(&self) -> &Datum { + &self.literal + } } impl Display for BinaryExpression { @@ -187,6 +212,16 @@ impl SetExpression { debug_assert!(op.is_set()); Self { op, term, literals } } + + /// Return the term of this predicate. + pub fn term(&self) -> &T { + &self.term + } + + /// Return the operator of this predicate. + pub fn op(&self) -> &PredicateOperator { + &self.op + } } impl Bind for SetExpression {