diff --git a/kernel/src/engine/arrow_expression.rs b/kernel/src/engine/arrow_expression.rs index a611eb5e6..76a902083 100644 --- a/kernel/src/engine/arrow_expression.rs +++ b/kernel/src/engine/arrow_expression.rs @@ -317,7 +317,8 @@ fn evaluate_expression( Equal => |l, r| eq(l, r).map(wrap_comparison_result), NotEqual => |l, r| neq(l, r).map(wrap_comparison_result), Distinct => |l, r| distinct(l, r).map(wrap_comparison_result), - _ => return Err(Error::generic("Invalid expression given")), + // NOTE: [Not]In was already covered above + In | NotIn => return Err(Error::generic("Invalid expression given")), }; eval(&left_arr, &right_arr).map_err(Error::generic_err) diff --git a/kernel/src/engine/parquet_row_group_skipping.rs b/kernel/src/engine/parquet_row_group_skipping.rs index d57e5dccc..e2e586480 100644 --- a/kernel/src/engine/parquet_row_group_skipping.rs +++ b/kernel/src/engine/parquet_row_group_skipping.rs @@ -1,5 +1,7 @@ //! An implementation of parquet row group skipping using data skipping predicates over footer stats. -use crate::engine::parquet_stats_skipping::ParquetStatsSkippingFilter; +use crate::engine::parquet_stats_skipping::{ + ParquetStatsProvider, ParquetStatsSkippingFilter as _, +}; use crate::expressions::{ColumnName, Expression, Scalar}; use crate::schema::{DataType, PrimitiveType}; use chrono::{DateTime, Days}; @@ -55,7 +57,7 @@ impl<'a> RowGroupFilter<'a> { /// Applies a filtering predicate to a row group. Return value false means to skip it. fn apply(row_group: &'a RowGroupMetaData, predicate: &Expression) -> bool { - RowGroupFilter::new(row_group, predicate).apply_sql_where(predicate) != Some(false) + RowGroupFilter::new(row_group, predicate).eval_sql_where(predicate) != Some(false) } /// Returns `None` if the column doesn't exist and `Some(None)` if the column has no stats. 
@@ -87,13 +89,13 @@ impl<'a> RowGroupFilter<'a> { } } -impl<'a> ParquetStatsSkippingFilter for RowGroupFilter<'a> { +impl<'a> ParquetStatsProvider for RowGroupFilter<'a> { // Extracts a stat value, converting from its physical type to the requested logical type. // // NOTE: This code is highly redundant with [`get_max_stat_value`] below, but parquet // ValueStatistics requires T to impl a private trait, so we can't factor out any kind of // helper method. And macros are hard enough to read that it's not worth defining one. - fn get_min_stat_value(&self, col: &ColumnName, data_type: &DataType) -> Option { + fn get_parquet_min_stat(&self, col: &ColumnName, data_type: &DataType) -> Option { use PrimitiveType::*; let value = match (data_type.as_primitive_opt()?, self.get_stats(col)??) { (String, Statistics::ByteArray(s)) => s.min_opt()?.as_utf8().ok()?.into(), @@ -135,7 +137,7 @@ impl<'a> ParquetStatsSkippingFilter for RowGroupFilter<'a> { Some(value) } - fn get_max_stat_value(&self, col: &ColumnName, data_type: &DataType) -> Option { + fn get_parquet_max_stat(&self, col: &ColumnName, data_type: &DataType) -> Option { use PrimitiveType::*; let value = match (data_type.as_primitive_opt()?, self.get_stats(col)??) { (String, Statistics::ByteArray(s)) => s.max_opt()?.as_utf8().ok()?.into(), @@ -177,7 +179,7 @@ impl<'a> ParquetStatsSkippingFilter for RowGroupFilter<'a> { Some(value) } - fn get_nullcount_stat_value(&self, col: &ColumnName) -> Option { + fn get_parquet_nullcount_stat(&self, col: &ColumnName) -> Option { // NOTE: Stats for any given column are optional, which may produce a NULL nullcount. But if // the column itself is missing, then we know all values are implied to be NULL. // @@ -187,7 +189,7 @@ impl<'a> ParquetStatsSkippingFilter for RowGroupFilter<'a> { // physical name mapping has been performed. Because we currently lack both the // validation and the name mapping support, we must disable this optimization for the // time being. 
See https://github.com/delta-incubator/delta-kernel-rs/issues/434. - return Some(self.get_rowcount_stat_value()).filter(|_| false); + return Some(self.get_parquet_rowcount_stat()).filter(|_| false); }; // WARNING: [`Statistics::null_count_opt`] returns Some(0) when the underlying stat is @@ -210,7 +212,7 @@ impl<'a> ParquetStatsSkippingFilter for RowGroupFilter<'a> { Some(nullcount? as i64) } - fn get_rowcount_stat_value(&self) -> i64 { + fn get_parquet_rowcount_stat(&self) -> i64 { self.row_group.num_rows() } } diff --git a/kernel/src/engine/parquet_row_group_skipping/tests.rs b/kernel/src/engine/parquet_row_group_skipping/tests.rs index c34eaa0a0..39a9c2ab5 100644 --- a/kernel/src/engine/parquet_row_group_skipping/tests.rs +++ b/kernel/src/engine/parquet_row_group_skipping/tests.rs @@ -1,4 +1,5 @@ use super::*; +use crate::predicates::DataSkippingPredicateEvaluator as _; use crate::expressions::{column_name, column_expr}; use crate::Expression; use parquet::arrow::arrow_reader::ArrowReaderMetadata; @@ -58,26 +59,20 @@ fn test_get_stat_values() { ]); let filter = RowGroupFilter::new(metadata.metadata().row_group(0), &columns); - assert_eq!(filter.get_rowcount_stat_value(), 5); + assert_eq!(filter.get_rowcount_stat(), Some(5)); // Only the BOOL column has any nulls - assert_eq!( - filter.get_nullcount_stat_value(&column_name!("bool")), - Some(3) - ); - assert_eq!( - filter.get_nullcount_stat_value(&column_name!("varlen.utf8")), - Some(0) - ); + assert_eq!(filter.get_nullcount_stat(&column_name!("bool")), Some(3)); + assert_eq!(filter.get_nullcount_stat(&column_name!("varlen.utf8")), Some(0)); assert_eq!( - filter.get_min_stat_value(&column_name!("varlen.utf8"), &DataType::STRING), + filter.get_min_stat(&column_name!("varlen.utf8"), &DataType::STRING), Some("a".into()) ); // CHEAT: Interpret the decimal128 column's fixed-length binary as a string assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.decimals.decimal128"), 
&DataType::STRING ), @@ -85,33 +80,33 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_min_stat_value(&column_name!("numeric.ints.int64"), &DataType::LONG), + filter.get_min_stat(&column_name!("numeric.ints.int64"), &DataType::LONG), Some(1000000000i64.into()) ); // type widening! assert_eq!( - filter.get_min_stat_value(&column_name!("numeric.ints.int32"), &DataType::LONG), + filter.get_min_stat(&column_name!("numeric.ints.int32"), &DataType::LONG), Some(1000000i64.into()) ); assert_eq!( - filter.get_min_stat_value(&column_name!("numeric.ints.int32"), &DataType::INTEGER), + filter.get_min_stat(&column_name!("numeric.ints.int32"), &DataType::INTEGER), Some(1000000i32.into()) ); assert_eq!( - filter.get_min_stat_value(&column_name!("numeric.ints.int16"), &DataType::SHORT), + filter.get_min_stat(&column_name!("numeric.ints.int16"), &DataType::SHORT), Some(1000i16.into()) ); assert_eq!( - filter.get_min_stat_value(&column_name!("numeric.ints.int8"), &DataType::BYTE), + filter.get_min_stat(&column_name!("numeric.ints.int8"), &DataType::BYTE), Some(0i8.into()) ); assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.floats.float64"), &DataType::DOUBLE ), @@ -120,7 +115,7 @@ fn test_get_stat_values() { // type widening! 
assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.floats.float32"), &DataType::DOUBLE ), @@ -128,7 +123,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.floats.float32"), &DataType::FLOAT ), @@ -136,18 +131,18 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_min_stat_value(&column_name!("bool"), &DataType::BOOLEAN), + filter.get_min_stat(&column_name!("bool"), &DataType::BOOLEAN), Some(false.into()) ); assert_eq!( - filter.get_min_stat_value(&column_name!("varlen.binary"), &DataType::BINARY), + filter.get_min_stat(&column_name!("varlen.binary"), &DataType::BINARY), Some([].as_slice().into()) ); // CHEAT: Interpret the decimal128 column's fixed-len array as binary assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.decimals.decimal128"), &DataType::BINARY ), @@ -159,7 +154,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.decimals.decimal32"), &DataType::decimal(8, 3).unwrap() ), @@ -167,7 +162,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.decimals.decimal64"), &DataType::decimal(16, 3).unwrap() ), @@ -176,7 +171,7 @@ fn test_get_stat_values() { // type widening! assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.decimals.decimal32"), &DataType::decimal(16, 3).unwrap() ), @@ -184,7 +179,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.decimals.decimal128"), &DataType::decimal(32, 3).unwrap() ), @@ -193,7 +188,7 @@ fn test_get_stat_values() { // type widening! 
assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.decimals.decimal64"), &DataType::decimal(32, 3).unwrap() ), @@ -202,7 +197,7 @@ fn test_get_stat_values() { // type widening! assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("numeric.decimals.decimal32"), &DataType::decimal(32, 3).unwrap() ), @@ -210,18 +205,18 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_min_stat_value(&column_name!("chrono.date32"), &DataType::DATE), + filter.get_min_stat(&column_name!("chrono.date32"), &DataType::DATE), Some(PrimitiveType::Date.parse_scalar("1971-01-01").unwrap()) ); assert_eq!( - filter.get_min_stat_value(&column_name!("chrono.timestamp"), &DataType::TIMESTAMP), + filter.get_min_stat(&column_name!("chrono.timestamp"), &DataType::TIMESTAMP), None // Timestamp defaults to 96-bit, which doesn't get stats ); // CHEAT: Interpret the timestamp_ntz column as a normal timestamp assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("chrono.timestamp_ntz"), &DataType::TIMESTAMP ), @@ -233,7 +228,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("chrono.timestamp_ntz"), &DataType::TIMESTAMP_NTZ ), @@ -246,7 +241,7 @@ fn test_get_stat_values() { // type widening! 
assert_eq!( - filter.get_min_stat_value( + filter.get_min_stat( &column_name!("chrono.date32"), &DataType::TIMESTAMP_NTZ ), @@ -258,13 +253,13 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat_value(&column_name!("varlen.utf8"), &DataType::STRING), + filter.get_max_stat(&column_name!("varlen.utf8"), &DataType::STRING), Some("e".into()) ); // CHEAT: Interpret the decimal128 column's fixed-length binary as a string assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.decimals.decimal128"), &DataType::STRING ), @@ -272,33 +267,33 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat_value(&column_name!("numeric.ints.int64"), &DataType::LONG), + filter.get_max_stat(&column_name!("numeric.ints.int64"), &DataType::LONG), Some(1000000004i64.into()) ); // type widening! assert_eq!( - filter.get_max_stat_value(&column_name!("numeric.ints.int32"), &DataType::LONG), + filter.get_max_stat(&column_name!("numeric.ints.int32"), &DataType::LONG), Some(1000004i64.into()) ); assert_eq!( - filter.get_max_stat_value(&column_name!("numeric.ints.int32"), &DataType::INTEGER), + filter.get_max_stat(&column_name!("numeric.ints.int32"), &DataType::INTEGER), Some(1000004.into()) ); assert_eq!( - filter.get_max_stat_value(&column_name!("numeric.ints.int16"), &DataType::SHORT), + filter.get_max_stat(&column_name!("numeric.ints.int16"), &DataType::SHORT), Some(1004i16.into()) ); assert_eq!( - filter.get_max_stat_value(&column_name!("numeric.ints.int8"), &DataType::BYTE), + filter.get_max_stat(&column_name!("numeric.ints.int8"), &DataType::BYTE), Some(4i8.into()) ); assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.floats.float64"), &DataType::DOUBLE ), @@ -307,7 +302,7 @@ fn test_get_stat_values() { // type widening! 
assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.floats.float32"), &DataType::DOUBLE ), @@ -315,7 +310,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.floats.float32"), &DataType::FLOAT ), @@ -323,18 +318,18 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat_value(&column_name!("bool"), &DataType::BOOLEAN), + filter.get_max_stat(&column_name!("bool"), &DataType::BOOLEAN), Some(true.into()) ); assert_eq!( - filter.get_max_stat_value(&column_name!("varlen.binary"), &DataType::BINARY), + filter.get_max_stat(&column_name!("varlen.binary"), &DataType::BINARY), Some([0, 0, 0, 0].as_slice().into()) ); // CHEAT: Interpret the decimal128 columns' fixed-len array as binary assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.decimals.decimal128"), &DataType::BINARY ), @@ -346,7 +341,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.decimals.decimal32"), &DataType::decimal(8, 3).unwrap() ), @@ -354,7 +349,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.decimals.decimal64"), &DataType::decimal(16, 3).unwrap() ), @@ -363,7 +358,7 @@ fn test_get_stat_values() { // type widening! assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.decimals.decimal32"), &DataType::decimal(16, 3).unwrap() ), @@ -371,7 +366,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.decimals.decimal128"), &DataType::decimal(32, 3).unwrap() ), @@ -380,7 +375,7 @@ fn test_get_stat_values() { // type widening! 
assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.decimals.decimal64"), &DataType::decimal(32, 3).unwrap() ), @@ -389,7 +384,7 @@ fn test_get_stat_values() { // type widening! assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("numeric.decimals.decimal32"), &DataType::decimal(32, 3).unwrap() ), @@ -397,18 +392,18 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat_value(&column_name!("chrono.date32"), &DataType::DATE), + filter.get_max_stat(&column_name!("chrono.date32"), &DataType::DATE), Some(PrimitiveType::Date.parse_scalar("1971-01-05").unwrap()) ); assert_eq!( - filter.get_max_stat_value(&column_name!("chrono.timestamp"), &DataType::TIMESTAMP), + filter.get_max_stat(&column_name!("chrono.timestamp"), &DataType::TIMESTAMP), None // Timestamp defaults to 96-bit, which doesn't get stats ); // CHEAT: Interpret the timestamp_ntz column as a normal timestamp assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("chrono.timestamp_ntz"), &DataType::TIMESTAMP ), @@ -420,7 +415,7 @@ fn test_get_stat_values() { ); assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("chrono.timestamp_ntz"), &DataType::TIMESTAMP_NTZ ), @@ -433,7 +428,7 @@ fn test_get_stat_values() { // type widening! assert_eq!( - filter.get_max_stat_value( + filter.get_max_stat( &column_name!("chrono.date32"), &DataType::TIMESTAMP_NTZ ), diff --git a/kernel/src/engine/parquet_stats_skipping.rs b/kernel/src/engine/parquet_stats_skipping.rs index bfe2eef36..c8c2fdfe2 100644 --- a/kernel/src/engine/parquet_stats_skipping.rs +++ b/kernel/src/engine/parquet_stats_skipping.rs @@ -1,14 +1,101 @@ //! An implementation of data skipping that leverages parquet stats from the file footer. 
use crate::expressions::{ - BinaryOperator, ColumnName, Expression, Scalar, UnaryOperator, VariadicOperator, + BinaryOperator, ColumnName, Expression as Expr, Scalar, UnaryOperator, VariadicOperator, +}; +use crate::predicates::{ + DataSkippingPredicateEvaluator, PredicateEvaluator, PredicateEvaluatorDefaults, }; use crate::schema::DataType; use std::cmp::Ordering; -use tracing::info; #[cfg(test)] mod tests; +/// A helper trait (mostly exposed for testing). It provides the four stats getters needed by +/// [`DataSkippingStatsProvider`]. From there, we can automatically derive a +/// [`DataSkippingPredicateEvaluator`]. +pub(crate) trait ParquetStatsProvider { + /// The min-value stat for this column, if the column exists in this file, has the expected + /// type, and the parquet footer provides stats for it. + fn get_parquet_min_stat(&self, _col: &ColumnName, _data_type: &DataType) -> Option; + + /// The max-value stat for this column, if the column exists in this file, has the expected + /// type, and the parquet footer provides stats for it. + fn get_parquet_max_stat(&self, _col: &ColumnName, _data_type: &DataType) -> Option; + + /// The nullcount stat for this column, if the column exists in this file, has the expected + /// type, and the parquet footer provides stats for it. + fn get_parquet_nullcount_stat(&self, _col: &ColumnName) -> Option; + + /// The rowcount stat for this row group. It is always available in the parquet footer. + fn get_parquet_rowcount_stat(&self) -> i64; +} + +/// Blanket implementation that converts a [`ParquetStatsProvider`] into a +/// [`DataSkippingPredicateEvaluator`]. 
+impl DataSkippingPredicateEvaluator for T { + type Output = bool; + type TypedStat = Scalar; + type IntStat = i64; + + fn get_min_stat(&self, col: &ColumnName, data_type: &DataType) -> Option { + self.get_parquet_min_stat(col, data_type) + } + + fn get_max_stat(&self, col: &ColumnName, data_type: &DataType) -> Option { + self.get_parquet_max_stat(col, data_type) + } + + fn get_nullcount_stat(&self, col: &ColumnName) -> Option { + self.get_parquet_nullcount_stat(col) + } + + fn get_rowcount_stat(&self) -> Option { + Some(self.get_parquet_rowcount_stat()) + } + + fn eval_partial_cmp( + &self, + ord: Ordering, + col: Scalar, + val: &Scalar, + inverted: bool, + ) -> Option { + PredicateEvaluatorDefaults::partial_cmp_scalars(ord, &col, val, inverted) + } + + fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option { + PredicateEvaluatorDefaults::eval_scalar(val, inverted) + } + + fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option { + let safe_to_skip = match inverted { + true => self.get_rowcount_stat()?, // all-null + false => 0i64, // no-null + }; + Some(self.get_nullcount_stat(col)? != safe_to_skip) + } + + fn eval_binary_scalars( + &self, + op: BinaryOperator, + left: &Scalar, + right: &Scalar, + inverted: bool, + ) -> Option { + PredicateEvaluatorDefaults::eval_binary_scalars(op, left, right, inverted) + } + + fn finish_eval_variadic( + &self, + op: VariadicOperator, + exprs: impl IntoIterator>, + inverted: bool, + ) -> Option { + PredicateEvaluatorDefaults::finish_eval_variadic(op, exprs, inverted) + } +} + /// Data skipping based on parquet footer stats (e.g. row group skipping). The required methods /// fetch stats values for requested columns (if available and with compatible types), and the /// provided methods implement the actual skipping logic. @@ -18,18 +105,6 @@ mod tests; /// methods of this class convert various supported expressions into data skipping predicates, and /// then return the result of evaluating the translated filter. 
pub(crate) trait ParquetStatsSkippingFilter { - /// Retrieves the minimum value of a column, if it exists and has the requested type. - fn get_min_stat_value(&self, col: &ColumnName, data_type: &DataType) -> Option; - - /// Retrieves the maximum value of a column, if it exists and has the requested type. - fn get_max_stat_value(&self, col: &ColumnName, data_type: &DataType) -> Option; - - /// Retrieves the null count of a column, if it exists. - fn get_nullcount_stat_value(&self, col: &ColumnName) -> Option; - - /// Retrieves the row count of a column (parquet footers always include this stat). - fn get_rowcount_stat_value(&self) -> i64; - /// Attempts to filter using SQL WHERE semantics. /// /// By default, [`apply_expr`] can produce unwelcome behavior for comparisons involving all-NULL @@ -73,23 +148,28 @@ pub(crate) trait ParquetStatsSkippingFilter { /// If the result was FALSE, it forces both inner and outer AND to FALSE, as desired. If the /// result was TRUE or NULL, then it does not contribute to data skipping but also does not /// block it if other legs of the AND evaluate to FALSE. - fn apply_sql_where(&self, filter: &Expression) -> Option { - use Expression::*; - use VariadicOperator::And; + // TODO: If these are generally useful, we may want to move them into PredicateEvaluator? 
+ fn eval_sql_where(&self, filter: &Expr) -> Option; + fn eval_binary_nullsafe(&self, op: BinaryOperator, left: &Expr, right: &Expr) -> Option; +} + +impl> ParquetStatsSkippingFilter for T { + fn eval_sql_where(&self, filter: &Expr) -> Option { + use Expr::{BinaryOperation, VariadicOperation}; match filter { - VariadicOperation { op: And, exprs } => { + VariadicOperation { op: VariadicOperator::And, exprs } => { let exprs: Vec<_> = exprs .iter() - .map(|expr| self.apply_sql_where(expr)) + .map(|expr| self.eval_sql_where(expr)) .map(|result| match result { - Some(value) => Expression::literal(value), - None => Expression::null_literal(DataType::BOOLEAN), + Some(value) => Expr::literal(value), + None => Expr::null_literal(DataType::BOOLEAN), }) .collect(); - self.apply_variadic(And, &exprs, false) + self.eval_variadic(VariadicOperator::And, &exprs, false) } - BinaryOperation { op, left, right } => self.apply_binary_nullsafe(*op, left, right), - _ => self.apply_expr(filter, false), + BinaryOperation { op, left, right } => self.eval_binary_nullsafe(*op, left, right), + _ => self.eval_expr(filter, false), } } @@ -103,293 +183,16 @@ pub(crate) trait ParquetStatsSkippingFilter { /// fails (producing FALSE), it short-circuits the entire AND without ever evaluating the /// comparison. Otherwise, the original comparison will run and -- if FALSE -- can cause data /// skipping as usual. - fn apply_binary_nullsafe( - &self, - op: BinaryOperator, - left: &Expression, - right: &Expression, - ) -> Option { + fn eval_binary_nullsafe(&self, op: BinaryOperator, left: &Expr, right: &Expr) -> Option { use UnaryOperator::IsNull; // Convert `a {cmp} b` to `AND(a IS NOT NULL, b IS NOT NULL, a {cmp} b)`, // and only evaluate the comparison if the null checks don't short circuit. 
- if let Some(false) = self.apply_unary(IsNull, left, true) { + if let Some(false) = self.eval_unary(IsNull, left, true) { return Some(false); } - if let Some(false) = self.apply_unary(IsNull, right, true) { + if let Some(false) = self.eval_unary(IsNull, right, true) { return Some(false); } - self.apply_binary(op, left, right, false) - } - - /// Evaluates a predicate over stats instead of rows. Evaluation is a depth-first traversal over - /// all supported subexpressions; unsupported expressions (or expressions that rely on missing - /// stats) are replaced with NULL (`None`) values, which then propagate upward following the - /// NULL semantics of their parent expressions. If stats prove the filter would eliminate ALL - /// rows from the result, then this method returns `Some(false)` and those rows can be skipped - /// without inspecting them individually. A return value of `Some(true)` means the filter does - /// not reliably eliminate all rows, and `None` indicates the needed stats were not available. - /// - /// If `inverted`, the caller requests to evaluate `NOT(expression)` instead of evaluating - /// `expression` directly. This is important because `NOT(data_skipping(expr))` is NOT - /// `equivalent to data_skipping(NOT(expr))`, so we need to "push down" the NOT in order to - /// ensure correct semantics. For example, given the expression `x == 10`, and min-max stats - /// 1..100, `NOT(x == 10)` and `x == 10` both evaluate to TRUE (because neither filter can - /// provably eliminate all rows). 
- fn apply_expr(&self, expression: &Expression, inverted: bool) -> Option { - use Expression::*; - match expression { - VariadicOperation { op, exprs } => self.apply_variadic(*op, exprs, inverted), - BinaryOperation { op, left, right } => self.apply_binary(*op, left, right, inverted), - UnaryOperation { op, expr } => self.apply_unary(*op, expr, inverted), - Literal(value) => Self::apply_scalar(value, inverted), - Column(col) => self.apply_column(col, inverted), - Struct(_) => None, // not supported - } - } - - /// Evaluates AND/OR expressions with Kleene semantics and short circuit behavior. - /// - /// Short circuiting is based on the observation that each operator has a "dominant" boolean - /// value that forces the output to match regardless of any other input. For example, a single - /// FALSE input forces AND to FALSE, and a single TRUE input forces OR to TRUE. - /// - /// Kleene semantics mean that -- in the absence of any dominant input -- a single NULL input - /// forces the output to NULL. If no NULL nor dominant input is seen, then the operator's output - /// "defaults" to the non-dominant value (and we can actually just ignore non-dominant inputs). - /// - /// If the filter is inverted, use de Morgan's laws to push the inversion down into the inputs - /// (e.g. `NOT(AND(a, b))` becomes `OR(NOT(a), NOT(b))`). - fn apply_variadic( - &self, - op: VariadicOperator, - exprs: &[Expression], - inverted: bool, - ) -> Option { - // With AND (OR), any FALSE (TRUE) input forces FALSE (TRUE) output. If there was no - // dominating input, then any NULL input forces NULL output. Otherwise, return the - // non-dominant value. Inverting the operation also inverts the dominant value. - let dominator = match op { - VariadicOperator::And => inverted, - VariadicOperator::Or => !inverted, - }; - - // Evaluate the input expressions, inverting each as needed and tracking whether we've seen - // any NULL result. Stop immediately (short circuit) if we see a dominant value. 
- let result = exprs.iter().try_fold(false, |found_null, expr| { - match self.apply_expr(expr, inverted) { - Some(v) if v == dominator => None, // (1) short circuit, dominant found - Some(_) => Some(found_null), - None => Some(true), // (2) null found (but keep looking for a dominant value) - } - }); - - match result { - None => Some(dominator), // (1) short circuit, dominant found - Some(false) => Some(!dominator), - Some(true) => None, // (2) null found, dominant not found - } - } - - /// Evaluates binary comparisons. Any NULL input produces a NULL output. If `inverted`, the - /// opposite operation is performed, e.g. `<` evaluates as if `>=` had been requested instead. - fn apply_binary( - &self, - op: BinaryOperator, - left: &Expression, - right: &Expression, - inverted: bool, - ) -> Option { - use BinaryOperator::*; - use Expression::{Column, Literal}; - - // Min/Max stats don't allow us to push inversion down into the comparison. Instead, we - // invert the comparison itself when needed and compute normally after that. - let op = match inverted { - true => op.invert()?, - false => op, - }; - - // NOTE: We rely on the literal values to provide logical type hints. That means we cannot - // perform column-column comparisons, because we cannot infer the logical type to use. 
- let (op, col, val) = match (left, right) { - (Column(col), Literal(val)) => (op, col, val), - (Literal(val), Column(col)) => (op.commute()?, col, val), - (Literal(a), Literal(b)) => return Self::apply_binary_scalars(op, a, b), - _ => { - info!("Unsupported binary operand(s): {left:?} {op:?} {right:?}"); - return None; - } - }; - let min_max_disjunct = |min_ord, max_ord, inverted| -> Option { - let skip_lo = self.partial_cmp_min_stat(col, val, min_ord, false)?; - let skip_hi = self.partial_cmp_max_stat(col, val, max_ord, false)?; - let skip = skip_lo || skip_hi; - Some(skip != inverted) - }; - match op { - // Given `col == val`: - // skip if `val` cannot equal _any_ value in [min, max], implies - // skip if `NOT(val BETWEEN min AND max)` implies - // skip if `NOT(min <= val AND val <= max)` implies - // skip if `min > val OR max < val` - // keep if `NOT(min > val OR max < val)` - Equal => min_max_disjunct(Ordering::Greater, Ordering::Less, true), - // Given `col != val`: - // skip if `val` equals _every_ value in [min, max], implies - // skip if `val == min AND val == max` implies - // skip if `val <= min AND min <= val AND val <= max AND max <= val` implies - // skip if `val <= min AND max <= val` implies - // keep if `NOT(val <= min AND max <= val)` implies - // keep if `val > min OR max > val` implies - // keep if `min < val OR max > val` - NotEqual => min_max_disjunct(Ordering::Less, Ordering::Greater, false), - // Given `col < val`: - // Skip if `val` is not greater than _all_ values in [min, max], implies - // Skip if `val <= min AND val <= max` implies - // Skip if `val <= min` implies - // Keep if `NOT(val <= min)` implies - // Keep if `val > min` implies - // Keep if `min < val` - LessThan => self.partial_cmp_min_stat(col, val, Ordering::Less, false), - // Given `col <= val`: - // Skip if `val` is less than _all_ values in [min, max], implies - // Skip if `val < min AND val < max` implies - // Skip if `val < min` implies - // Keep if `NOT(val < min)` 
implies - // Keep if `NOT(min > val)` - LessThanOrEqual => self.partial_cmp_min_stat(col, val, Ordering::Greater, true), - // Given `col > val`: - // Skip if `val` is not less than _all_ values in [min, max], implies - // Skip if `val >= min AND val >= max` implies - // Skip if `val >= max` implies - // Keep if `NOT(val >= max)` implies - // Keep if `NOT(max <= val)` implies - // Keep if `max > val` - GreaterThan => self.partial_cmp_max_stat(col, val, Ordering::Greater, false), - // Given `col >= val`: - // Skip if `val is greater than _every_ value in [min, max], implies - // Skip if `val > min AND val > max` implies - // Skip if `val > max` implies - // Keep if `NOT(val > max)` implies - // Keep if `NOT(max < val)` - GreaterThanOrEqual => self.partial_cmp_max_stat(col, val, Ordering::Less, true), - _ => { - info!("Unsupported binary operator: {left:?} {op:?} {right:?}"); - None - } - } - } - - /// Helper method, invoked by [`apply_binary`], for constant comparisons. Query planner constant - /// folding optimizationss SHOULD eliminate such patterns, but we implement the support anyway - /// because propagating a NULL in its place could disable skipping entirely, e.g. an expression - /// such as `OR(10 == 20, )`. - fn apply_binary_scalars(op: BinaryOperator, left: &Scalar, right: &Scalar) -> Option { - use BinaryOperator::*; - match op { - Equal => partial_cmp_scalars(left, right, Ordering::Equal, false), - NotEqual => partial_cmp_scalars(left, right, Ordering::Equal, true), - LessThan => partial_cmp_scalars(left, right, Ordering::Less, false), - LessThanOrEqual => partial_cmp_scalars(left, right, Ordering::Greater, true), - GreaterThan => partial_cmp_scalars(left, right, Ordering::Greater, false), - GreaterThanOrEqual => partial_cmp_scalars(left, right, Ordering::Less, true), - _ => { - info!("Unsupported binary operator: {left:?} {op:?} {right:?}"); - None - } - } - } - - /// Applies unary NOT and IS [NOT] NULL. Null inputs to NOT produce NULL output. 
The null checks - /// are only defined for columns (not expressions), and they ony produce NULL output if the - /// necessary nullcount stats are missing. - fn apply_unary(&self, op: UnaryOperator, expr: &Expression, inverted: bool) -> Option { - match op { - UnaryOperator::Not => self.apply_expr(expr, !inverted), - UnaryOperator::IsNull => match expr { - Expression::Column(col) => { - let skip = match inverted { - // IS NOT NULL - skip if all-null - true => self.get_rowcount_stat_value(), - // IS NULL - skip if no-null - false => 0, - }; - Some(self.get_nullcount_stat_value(col)? != skip) - } - _ => { - info!("Unsupported unary operation: {op:?}({expr:?})"); - None - } - }, - } - } - - /// Propagates a boolean-typed column, allowing e.g. `flag OR ...`. - /// Columns of other types are ignored (NULL result). - fn apply_column(&self, col: &ColumnName, inverted: bool) -> Option { - let as_boolean = |get: &dyn Fn(_, _, _) -> _| match get(self, col, &DataType::BOOLEAN) { - Some(Scalar::Boolean(value)) => Some(value), - Some(_) => { - info!("Ignoring non-boolean column {col}"); - None - } - _ => None, - }; - let min = as_boolean(&Self::get_min_stat_value)?; - let max = as_boolean(&Self::get_max_stat_value)?; - Some(min != inverted || max != inverted) - } - - /// Propagates a boolean literal, allowing e.g. `FALSE OR ...`. - /// Literals of other types are ignored (NULL result). - fn apply_scalar(value: &Scalar, inverted: bool) -> Option { - match value { - Scalar::Boolean(value) => Some(*value != inverted), - _ => { - info!("Ignoring non-boolean literal {value}"); - None - } - } - } - - /// Performs a partial comparison against a column min-stat. See [`partial_cmp_scalars`] for - /// details of the comparison semantics. 
- fn partial_cmp_min_stat( - &self, - col: &ColumnName, - val: &Scalar, - ord: Ordering, - inverted: bool, - ) -> Option { - let min = self.get_min_stat_value(col, &val.data_type())?; - partial_cmp_scalars(&min, val, ord, inverted) - } - - /// Performs a partial comparison against a column max-stat. See [`partial_cmp_scalars`] for - /// details of the comparison semantics. - fn partial_cmp_max_stat( - &self, - col: &ColumnName, - val: &Scalar, - ord: Ordering, - inverted: bool, - ) -> Option { - let max = self.get_max_stat_value(col, &val.data_type())?; - partial_cmp_scalars(&max, val, ord, inverted) + self.eval_binary(op, left, right, false) } } - -/// Compares two scalar values, returning Some(true) if the result matches the target `Ordering`. If -/// an inverted comparison is requested, then return Some(false) on match instead. For example, -/// requesting an inverted `Ordering::Less` matches both `Ordering::Greater` and `Ordering::Equal`, -/// corresponding to a logical `>=` comparison. Returns `None` if the values are incomparable, which -/// can occur because the types differ or because the type itself is incomparable. -pub(crate) fn partial_cmp_scalars( - a: &Scalar, - b: &Scalar, - ord: Ordering, - inverted: bool, -) -> Option { - let result = a.partial_cmp(b)? 
== ord; - Some(result != inverted) -} diff --git a/kernel/src/engine/parquet_stats_skipping/tests.rs b/kernel/src/engine/parquet_stats_skipping/tests.rs index bc8ce1e78..9d35fac5a 100644 --- a/kernel/src/engine/parquet_stats_skipping/tests.rs +++ b/kernel/src/engine/parquet_stats_skipping/tests.rs @@ -1,32 +1,11 @@ use super::*; -use crate::expressions::{column_expr, column_name, ArrayData, StructData}; -use crate::schema::ArrayType; +use crate::expressions::{column_expr, Expression as Expr}; +use crate::predicates::PredicateEvaluator; use crate::DataType; -struct UnimplementedTestFilter; -impl ParquetStatsSkippingFilter for UnimplementedTestFilter { - fn get_min_stat_value(&self, _col: &ColumnName, _data_type: &DataType) -> Option { - unimplemented!() - } - - fn get_max_stat_value(&self, _col: &ColumnName, _data_type: &DataType) -> Option { - unimplemented!() - } - - fn get_nullcount_stat_value(&self, _col: &ColumnName) -> Option { - unimplemented!() - } - - fn get_rowcount_stat_value(&self) -> i64 { - unimplemented!() - } -} - -struct JunctionTest { - inputs: &'static [Option], - expect_and: Option, - expect_or: Option, -} +const TRUE: Option = Some(true); +const FALSE: Option = Some(false); +const NULL: Option = None; macro_rules! expect_eq { ( $expr: expr, $expect: expr, $fmt: literal ) => { @@ -41,259 +20,104 @@ macro_rules! 
expect_eq { ); }; } -impl JunctionTest { - fn new( - inputs: &'static [Option], - expect_and: Option, - expect_or: Option, - ) -> Self { - Self { - inputs, - expect_and, - expect_or, - } + +struct UnimplementedTestFilter; +impl ParquetStatsProvider for UnimplementedTestFilter { + fn get_parquet_min_stat(&self, _col: &ColumnName, _data_type: &DataType) -> Option { + unimplemented!() } - fn do_test(&self) { - use VariadicOperator::*; - let filter = UnimplementedTestFilter; - let inputs: Vec<_> = self - .inputs - .iter() - .map(|val| match val { - Some(v) => Expression::literal(v), - None => Expression::null_literal(DataType::BOOLEAN), - }) - .collect(); - expect_eq!( - filter.apply_variadic(And, &inputs, false), - self.expect_and, - "AND({inputs:?})" - ); - expect_eq!( - filter.apply_variadic(Or, &inputs, false), - self.expect_or, - "OR({inputs:?})" - ); - expect_eq!( - filter.apply_variadic(And, &inputs, true), - self.expect_and.map(|val| !val), - "NOT(AND({inputs:?}))" - ); - expect_eq!( - filter.apply_variadic(Or, &inputs, true), - self.expect_or.map(|val| !val), - "NOT(OR({inputs:?}))" - ); + fn get_parquet_max_stat(&self, _col: &ColumnName, _data_type: &DataType) -> Option { + unimplemented!() + } + + fn get_parquet_nullcount_stat(&self, _col: &ColumnName) -> Option { + unimplemented!() + } + + fn get_parquet_rowcount_stat(&self) -> i64 { + unimplemented!() } } /// Tests apply_variadic and apply_scalar #[test] fn test_junctions() { - let test_case = JunctionTest::new; - const TRUE: Option = Some(true); - const FALSE: Option = Some(false); - const NULL: Option = None; + use VariadicOperator::*; + + let test_cases = &[ // Every combo of 0, 1 and 2 inputs - test_case(&[], TRUE, FALSE), - test_case(&[TRUE], TRUE, TRUE), - test_case(&[FALSE], FALSE, FALSE), - test_case(&[NULL], NULL, NULL), - test_case(&[TRUE, TRUE], TRUE, TRUE), - test_case(&[TRUE, FALSE], FALSE, TRUE), - test_case(&[TRUE, NULL], NULL, TRUE), - test_case(&[FALSE, TRUE], FALSE, TRUE), - 
test_case(&[FALSE, FALSE], FALSE, FALSE), - test_case(&[FALSE, NULL], FALSE, NULL), - test_case(&[NULL, TRUE], NULL, TRUE), - test_case(&[NULL, FALSE], FALSE, NULL), - test_case(&[NULL, NULL], NULL, NULL), + (&[] as &[Option], TRUE, FALSE), + (&[TRUE], TRUE, TRUE), + (&[FALSE], FALSE, FALSE), + (&[NULL], NULL, NULL), + (&[TRUE, TRUE], TRUE, TRUE), + (&[TRUE, FALSE], FALSE, TRUE), + (&[TRUE, NULL], NULL, TRUE), + (&[FALSE, TRUE], FALSE, TRUE), + (&[FALSE, FALSE], FALSE, FALSE), + (&[FALSE, NULL], FALSE, NULL), + (&[NULL, TRUE], NULL, TRUE), + (&[NULL, FALSE], FALSE, NULL), + (&[NULL, NULL], NULL, NULL), // Every combo of 1:2 - test_case(&[TRUE, FALSE, FALSE], FALSE, TRUE), - test_case(&[FALSE, TRUE, FALSE], FALSE, TRUE), - test_case(&[FALSE, FALSE, TRUE], FALSE, TRUE), - test_case(&[TRUE, NULL, NULL], NULL, TRUE), - test_case(&[NULL, TRUE, NULL], NULL, TRUE), - test_case(&[NULL, NULL, TRUE], NULL, TRUE), - test_case(&[FALSE, TRUE, TRUE], FALSE, TRUE), - test_case(&[TRUE, FALSE, TRUE], FALSE, TRUE), - test_case(&[TRUE, TRUE, FALSE], FALSE, TRUE), - test_case(&[FALSE, NULL, NULL], FALSE, NULL), - test_case(&[NULL, FALSE, NULL], FALSE, NULL), - test_case(&[NULL, NULL, FALSE], FALSE, NULL), - test_case(&[NULL, TRUE, TRUE], NULL, TRUE), - test_case(&[TRUE, NULL, TRUE], NULL, TRUE), - test_case(&[TRUE, TRUE, NULL], NULL, TRUE), - test_case(&[NULL, FALSE, FALSE], FALSE, NULL), - test_case(&[FALSE, NULL, FALSE], FALSE, NULL), - test_case(&[FALSE, FALSE, NULL], FALSE, NULL), + (&[TRUE, FALSE, FALSE], FALSE, TRUE), + (&[FALSE, TRUE, FALSE], FALSE, TRUE), + (&[FALSE, FALSE, TRUE], FALSE, TRUE), + (&[TRUE, NULL, NULL], NULL, TRUE), + (&[NULL, TRUE, NULL], NULL, TRUE), + (&[NULL, NULL, TRUE], NULL, TRUE), + (&[FALSE, TRUE, TRUE], FALSE, TRUE), + (&[TRUE, FALSE, TRUE], FALSE, TRUE), + (&[TRUE, TRUE, FALSE], FALSE, TRUE), + (&[FALSE, NULL, NULL], FALSE, NULL), + (&[NULL, FALSE, NULL], FALSE, NULL), + (&[NULL, NULL, FALSE], FALSE, NULL), + (&[NULL, TRUE, TRUE], NULL, TRUE), + 
(&[TRUE, NULL, TRUE], NULL, TRUE), + (&[TRUE, TRUE, NULL], NULL, TRUE), + (&[NULL, FALSE, FALSE], FALSE, NULL), + (&[FALSE, NULL, FALSE], FALSE, NULL), + (&[FALSE, FALSE, NULL], FALSE, NULL), // Every unique ordering of 3 - test_case(&[TRUE, FALSE, NULL], FALSE, TRUE), - test_case(&[TRUE, NULL, FALSE], FALSE, TRUE), - test_case(&[FALSE, TRUE, NULL], FALSE, TRUE), - test_case(&[FALSE, NULL, TRUE], FALSE, TRUE), - test_case(&[NULL, TRUE, FALSE], FALSE, TRUE), - test_case(&[NULL, FALSE, TRUE], FALSE, TRUE), + (&[TRUE, FALSE, NULL], FALSE, TRUE), + (&[TRUE, NULL, FALSE], FALSE, TRUE), + (&[FALSE, TRUE, NULL], FALSE, TRUE), + (&[FALSE, NULL, TRUE], FALSE, TRUE), + (&[NULL, TRUE, FALSE], FALSE, TRUE), + (&[NULL, FALSE, TRUE], FALSE, TRUE), ]; - for test_case in test_cases { - test_case.do_test(); - } -} - -// tests apply_binary_scalars -#[test] -fn test_binary_scalars() { - use Scalar::*; - let smaller_values = &[ - Integer(1), - Long(1), - Short(1), - Byte(1), - Float(1.0), - Double(1.0), - String("1".into()), - Boolean(false), - Timestamp(1), - TimestampNtz(1), - Date(1), - Binary(vec![1]), - Decimal(1, 10, 10), // invalid value, - Null(DataType::LONG), - Struct(StructData::try_new(vec![], vec![]).unwrap()), - Array(ArrayData::new( - ArrayType::new(DataType::LONG, false), - Vec::::new(), - )), - ]; - let larger_values = &[ - Integer(10), - Long(10), - Short(10), - Byte(10), - Float(10.0), - Double(10.0), - String("10".into()), - Boolean(true), - Timestamp(10), - TimestampNtz(10), - Date(10), - Binary(vec![10]), - Decimal(10, 10, 10), // invalid value - Null(DataType::LONG), - Struct(StructData::try_new(vec![], vec![]).unwrap()), - Array(ArrayData::new( - ArrayType::new(DataType::LONG, false), - Vec::::new(), - )), - ]; - - // scalars of different types are always incomparable - use BinaryOperator::*; - let binary_ops = [ - Equal, - NotEqual, - LessThan, - LessThanOrEqual, - GreaterThan, - GreaterThanOrEqual, - ]; - let compare = 
UnimplementedTestFilter::apply_binary_scalars; - for (i, a) in smaller_values.iter().enumerate() { - for b in smaller_values.iter().skip(i + 1) { - for op in binary_ops { - let result = compare(op, a, b); - let a_type = a.data_type(); - let b_type = b.data_type(); - assert!( - result.is_none(), - "{a_type:?} should not be comparable to {b_type:?}" - ); - } - } - } - - let expect_if_comparable_type = |s: &_, expect| match s { - Null(_) | Decimal(..) | Struct(_) | Array(_) => None, - _ => Some(expect), - }; - - // Test same-type comparisons where a == b - for (a, b) in smaller_values.iter().zip(smaller_values.iter()) { - expect_eq!( - compare(Equal, a, b), - expect_if_comparable_type(a, true), - "{a:?} == {b:?}" - ); - expect_eq!( - compare(NotEqual, a, b), - expect_if_comparable_type(a, false), - "{a:?} != {b:?}" - ); - - expect_eq!( - compare(LessThan, a, b), - expect_if_comparable_type(a, false), - "{a:?} < {b:?}" - ); - - expect_eq!( - compare(GreaterThan, a, b), - expect_if_comparable_type(a, false), - "{a:?} > {b:?}" - ); - - expect_eq!( - compare(LessThanOrEqual, a, b), - expect_if_comparable_type(a, true), - "{a:?} <= {b:?}" - ); - - expect_eq!( - compare(GreaterThanOrEqual, a, b), - expect_if_comparable_type(a, true), - "{a:?} >= {b:?}" - ); - } - - // Test same-type comparisons where a < b - for (a, b) in smaller_values.iter().zip(larger_values.iter()) { - expect_eq!( - compare(Equal, a, b), - expect_if_comparable_type(a, false), - "{a:?} == {b:?}" - ); - - expect_eq!( - compare(NotEqual, a, b), - expect_if_comparable_type(a, true), - "{a:?} != {b:?}" - ); + let filter = UnimplementedTestFilter; + for (inputs, expect_and, expect_or) in test_cases { + let inputs: Vec<_> = inputs + .iter() + .map(|val| match val { + Some(v) => Expr::literal(v), + None => Expr::null_literal(DataType::BOOLEAN), + }) + .collect(); expect_eq!( - compare(LessThan, a, b), - expect_if_comparable_type(a, true), - "{a:?} < {b:?}" + filter.eval_variadic(And, &inputs, false), + 
*expect_and, + "AND({inputs:?})" ); - expect_eq!( - compare(GreaterThan, a, b), - expect_if_comparable_type(a, false), - "{a:?} > {b:?}" + filter.eval_variadic(Or, &inputs, false), + *expect_or, + "OR({inputs:?})" ); - expect_eq!( - compare(LessThanOrEqual, a, b), - expect_if_comparable_type(a, true), - "{a:?} <= {b:?}" + filter.eval_variadic(And, &inputs, true), + expect_and.map(|val| !val), + "NOT(AND({inputs:?}))" ); - expect_eq!( - compare(GreaterThanOrEqual, a, b), - expect_if_comparable_type(a, false), - "{a:?} >= {b:?}" + filter.eval_variadic(Or, &inputs, true), + expect_or.map(|val| !val), + "NOT(OR({inputs:?}))" ); } } @@ -312,394 +136,70 @@ impl MinMaxTestFilter { .cloned() } } -impl ParquetStatsSkippingFilter for MinMaxTestFilter { - fn get_min_stat_value(&self, _col: &ColumnName, data_type: &DataType) -> Option { +impl ParquetStatsProvider for MinMaxTestFilter { + fn get_parquet_min_stat(&self, _col: &ColumnName, data_type: &DataType) -> Option { Self::get_stat_value(&self.min, data_type) } - fn get_max_stat_value(&self, _col: &ColumnName, data_type: &DataType) -> Option { + fn get_parquet_max_stat(&self, _col: &ColumnName, data_type: &DataType) -> Option { Self::get_stat_value(&self.max, data_type) } - fn get_nullcount_stat_value(&self, _col: &ColumnName) -> Option { + fn get_parquet_nullcount_stat(&self, _col: &ColumnName) -> Option { unimplemented!() } - fn get_rowcount_stat_value(&self) -> i64 { + fn get_parquet_rowcount_stat(&self) -> i64 { unimplemented!() } } -#[test] -fn test_binary_eq_ne() { - use BinaryOperator::*; - - const LO: Scalar = Scalar::Long(1); - const MID: Scalar = Scalar::Long(10); - const HI: Scalar = Scalar::Long(100); - let col = &column_expr!("x"); - - for inverted in [false, true] { - // negative test -- mismatched column type - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - Equal, - col, - &Expression::literal("10"), - inverted, - ), - None, - "{col} == '10' (min: {MID}, max: {MID}, inverted: 
{inverted})" - ); - - // quick test for literal-literal comparisons - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - Equal, - &MID.into(), - &MID.into(), - inverted, - ), - Some(!inverted), - "{MID} == {MID} (min: {MID}, max: {MID}, inverted: {inverted})" - ); - - // quick test for literal-column comparisons - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - Equal, - &MID.into(), - col, - inverted, - ), - Some(!inverted), - "{MID} == {col} (min: {MID}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - Equal, - col, - &MID.into(), - inverted, - ), - Some(!inverted), - "{col} == {MID} (min: {MID}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), HI.into()).apply_binary( - Equal, - col, - &MID.into(), - inverted, - ), - Some(true), // min..max range includes both EQ and NE - "{col} == {MID} (min: {LO}, max: {HI}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), MID.into()).apply_binary( - Equal, - col, - &HI.into(), - inverted, - ), - Some(inverted), - "{col} == {HI} (min: {LO}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(MID.into(), HI.into()).apply_binary( - Equal, - col, - &LO.into(), - inverted, - ), - Some(inverted), - "{col} == {LO} (min: {MID}, max: {HI}, inverted: {inverted})" - ); - - // negative test -- mismatched column type - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - NotEqual, - col, - &Expression::literal("10"), - inverted, - ), - None, - "{col} != '10' (min: {MID}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - NotEqual, - col, - &MID.into(), - inverted, - ), - Some(inverted), - "{col} != {MID} (min: {MID}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), 
HI.into()).apply_binary( - NotEqual, - col, - &MID.into(), - inverted, - ), - Some(true), // min..max range includes both EQ and NE - "{col} != {MID} (min: {LO}, max: {HI}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), MID.into()).apply_binary( - NotEqual, - col, - &HI.into(), - inverted, - ), - Some(!inverted), - "{col} != {HI} (min: {LO}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(MID.into(), HI.into()).apply_binary( - NotEqual, - col, - &LO.into(), - inverted, - ), - Some(!inverted), - "{col} != {LO} (min: {MID}, max: {HI}, inverted: {inverted})" - ); - } -} - -#[test] -fn test_binary_lt_ge() { - use BinaryOperator::*; - - const LO: Scalar = Scalar::Long(1); - const MID: Scalar = Scalar::Long(10); - const HI: Scalar = Scalar::Long(100); - let col = &column_expr!("x"); - - for inverted in [false, true] { - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - LessThan, - col, - &MID.into(), - inverted, - ), - Some(inverted), - "{col} < {MID} (min: {MID}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), HI.into()).apply_binary( - LessThan, - col, - &MID.into(), - inverted, - ), - Some(true), // min..max range includes both LT and GE - "{col} < {MID} (min: {LO}, max: {HI}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), MID.into()).apply_binary( - LessThan, - col, - &HI.into(), - inverted, - ), - Some(!inverted), - "{col} < {HI} (min: {LO}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(MID.into(), HI.into()).apply_binary( - LessThan, - col, - &LO.into(), - inverted, - ), - Some(inverted), - "{col} < {LO} (min: {MID}, max: {HI}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - GreaterThanOrEqual, - col, - &MID.into(), - inverted, - ), - Some(!inverted), - "{col} >= {MID} (min: {MID}, max: {MID}, inverted: 
{inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), HI.into()).apply_binary( - GreaterThanOrEqual, - col, - &MID.into(), - inverted, - ), - Some(true), // min..max range includes both EQ and NE - "{col} >= {MID} (min: {LO}, max: {HI}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), MID.into()).apply_binary( - GreaterThanOrEqual, - col, - &HI.into(), - inverted, - ), - Some(inverted), - "{col} >= {HI} (min: {LO}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(MID.into(), HI.into()).apply_binary( - GreaterThanOrEqual, - col, - &LO.into(), - inverted, - ), - Some(!inverted), - "{col} >= {LO} (min: {MID}, max: {HI}, inverted: {inverted})" - ); - } -} #[test] -fn test_binary_le_gt() { - use BinaryOperator::*; - - const LO: Scalar = Scalar::Long(1); - const MID: Scalar = Scalar::Long(10); - const HI: Scalar = Scalar::Long(100); - let col = &column_expr!("x"); - - for inverted in [false, true] { - // negative test -- mismatched column type - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - LessThanOrEqual, - col, - &Expression::literal("10"), - inverted, - ), - None, - "{col} <= '10' (min: {MID}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - LessThanOrEqual, - col, - &MID.into(), - inverted, - ), - Some(!inverted), - "{col} <= {MID} (min: {MID}, max: {MID}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), HI.into()).apply_binary( - LessThanOrEqual, - col, - &MID.into(), - inverted, - ), - Some(true), // min..max range includes both LT and GE - "{col} <= {MID} (min: {LO}, max: {HI}, inverted: {inverted})" - ); - - expect_eq!( - MinMaxTestFilter::new(LO.into(), MID.into()).apply_binary( - LessThanOrEqual, - col, - &HI.into(), - inverted, - ), - Some(!inverted), - "{col} <= {HI} (min: {LO}, max: {MID}, inverted: {inverted})" - ); +fn test_eval_binary_comparisons() { 
+ const FIVE: Scalar = Scalar::Integer(5); + const TEN: Scalar = Scalar::Integer(10); + const FIFTEEN: Scalar = Scalar::Integer(15); + const NULL_VAL: Scalar = Scalar::Null(DataType::INTEGER); + + let expressions = [ + Expr::lt(column_expr!("x"), 10), + Expr::le(column_expr!("x"), 10), + Expr::eq(column_expr!("x"), 10), + Expr::ne(column_expr!("x"), 10), + Expr::gt(column_expr!("x"), 10), + Expr::ge(column_expr!("x"), 10), + ]; - expect_eq!( - MinMaxTestFilter::new(MID.into(), HI.into()).apply_binary( - LessThanOrEqual, - col, - &LO.into(), - inverted, - ), - Some(inverted), - "{col} <= {LO} (min: {MID}, max: {HI}, inverted: {inverted})" - ); + let do_test = |min: Scalar, max: Scalar, expected: &[Option]| { + let filter = MinMaxTestFilter::new(Some(min.clone()), Some(max.clone())); + for (expr, expect) in expressions.iter().zip(expected.iter()) { + expect_eq!( + filter.eval_expr(expr, false), + *expect, + "{expr:#?} with [{min}..{max}]" + ); + } + }; - // negative test -- mismatched column type - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - GreaterThan, - col, - &Expression::literal("10"), - inverted, - ), - None, - "{col} > '10' (min: {MID}, max: {MID}, inverted: {inverted})" - ); + // value < min = max (15..15 = 10, 15..15 <= 10, etc) + do_test(FIFTEEN, FIFTEEN, &[FALSE, FALSE, FALSE, TRUE, TRUE, TRUE]); - expect_eq!( - MinMaxTestFilter::new(MID.into(), MID.into()).apply_binary( - GreaterThan, - col, - &MID.into(), - inverted, - ), - Some(inverted), - "{col} > {MID} (min: {MID}, max: {MID}, inverted: {inverted})" - ); + // min = max = value (10..10 = 10, 10..10 <= 10, etc) + // + // NOTE: missing min or max stat produces NULL output if the expression needed it. 
+ do_test(TEN, TEN, &[FALSE, TRUE, TRUE, FALSE, FALSE, TRUE]); + do_test(NULL_VAL, TEN, &[NULL, NULL, NULL, NULL, FALSE, TRUE]); + do_test(TEN, NULL_VAL, &[FALSE, TRUE, NULL, NULL, NULL, NULL]); - expect_eq!( - MinMaxTestFilter::new(LO.into(), HI.into()).apply_binary( - GreaterThan, - col, - &MID.into(), - inverted, - ), - Some(true), // min..max range includes both EQ and NE - "{col} > {MID} (min: {LO}, max: {HI}, inverted: {inverted})" - ); + // min = max < value (5..5 = 10, 5..5 <= 10, etc) + do_test(FIVE, FIVE, &[TRUE, TRUE, FALSE, TRUE, FALSE, FALSE]); - expect_eq!( - MinMaxTestFilter::new(LO.into(), MID.into()).apply_binary( - GreaterThan, - col, - &HI.into(), - inverted, - ), - Some(inverted), - "{col} > {HI} (min: {LO}, max: {MID}, inverted: {inverted})" - ); + // value = min < max (10..15 = 10, 10..15 <= 10, etc) + do_test(TEN, FIFTEEN, &[FALSE, TRUE, TRUE, TRUE, TRUE, TRUE]); - expect_eq!( - MinMaxTestFilter::new(MID.into(), HI.into()).apply_binary( - GreaterThan, - col, - &LO.into(), - inverted, - ), - Some(!inverted), - "{col} > {LO} (min: {MID}, max: {HI}, inverted: {inverted})" - ); - } + // min < value < max (5..15 = 10, 5..15 <= 10, etc) + do_test(FIVE, FIFTEEN, &[TRUE, TRUE, TRUE, TRUE, TRUE, TRUE]); }
get_rowcount_stat_value(&self) -> i64 { + fn get_parquet_rowcount_stat(&self) -> i64 { self.rowcount } } #[test] -fn test_not_null() { - use UnaryOperator::IsNull; - - let col = &column_expr!("x"); - for inverted in [false, true] { - expect_eq!( - NullCountTestFilter::new(None, 10).apply_unary(IsNull, col, inverted), - None, - "{col} IS NULL (nullcount: None, rowcount: 10, inverted: {inverted})" - ); +fn test_eval_is_null() { + let expressions = [ + Expr::is_null(column_expr!("x")), + !Expr::is_null(column_expr!("x")) + ]; - expect_eq!( - NullCountTestFilter::new(Some(0), 10).apply_unary(IsNull, col, inverted), - Some(inverted), - "{col} IS NULL (nullcount: 0, rowcount: 10, inverted: {inverted})" - ); + let do_test = |nullcount: i64, expected: &[Option]| { + let filter = NullCountTestFilter::new(Some(nullcount), 2); + for (expr, expect) in expressions.iter().zip(expected) { + expect_eq!( + filter.eval_expr(expr, false), + *expect, + "{expr:#?} ({nullcount} nulls)" + ); + } + }; - expect_eq!( - NullCountTestFilter::new(Some(5), 10).apply_unary(IsNull, col, inverted), - Some(true), - "{col} IS NULL (nullcount: 5, rowcount: 10, inverted: {inverted})" - ); + // no nulls + do_test(0, &[FALSE, TRUE]); - expect_eq!( - NullCountTestFilter::new(Some(10), 10).apply_unary(IsNull, col, inverted), - Some(!inverted), - "{col} IS NULL (nullcount: 10, rowcount: 10, inverted: {inverted})" - ); - } -} + // some nulls + do_test(1, &[TRUE, TRUE]); -#[test] -fn test_bool_col() { - use Scalar::Boolean; - const TRUE: Scalar = Boolean(true); - const FALSE: Scalar = Boolean(false); - for inverted in [false, true] { - expect_eq!( - MinMaxTestFilter::new(TRUE.into(), TRUE.into()) - .apply_column(&column_name!("x"), inverted), - Some(!inverted), - "x as boolean (min: TRUE, max: TRUE, inverted: {inverted})" - ); - expect_eq!( - MinMaxTestFilter::new(FALSE.into(), TRUE.into()) - .apply_column(&column_name!("x"), inverted), - Some(true), - "x as boolean (min: FALSE, max: TRUE, inverted: 
{inverted})" - ); - expect_eq!( - MinMaxTestFilter::new(FALSE.into(), FALSE.into()) - .apply_column(&column_name!("x"), inverted), - Some(inverted), - "x as boolean (min: FALSE, max: FALSE, inverted: {inverted})" - ); - } + // all nulls + do_test(2, &[TRUE, FALSE]); } struct AllNullTestFilter; -impl ParquetStatsSkippingFilter for AllNullTestFilter { - fn get_min_stat_value(&self, _col: &ColumnName, _data_type: &DataType) -> Option { +impl ParquetStatsProvider for AllNullTestFilter { + fn get_parquet_min_stat(&self, _col: &ColumnName, _data_type: &DataType) -> Option { None } - fn get_max_stat_value(&self, _col: &ColumnName, _data_type: &DataType) -> Option { + fn get_parquet_max_stat(&self, _col: &ColumnName, _data_type: &DataType) -> Option { None } - fn get_nullcount_stat_value(&self, _col: &ColumnName) -> Option { - Some(self.get_rowcount_stat_value()) + fn get_parquet_nullcount_stat(&self, _col: &ColumnName) -> Option { + Some(self.get_parquet_rowcount_stat()) } - fn get_rowcount_stat_value(&self) -> i64 { + fn get_parquet_rowcount_stat(&self) -> i64 { 10 } } @@ -813,121 +282,121 @@ impl ParquetStatsSkippingFilter for AllNullTestFilter { #[test] fn test_sql_where() { let col = &column_expr!("x"); - let val = &Expression::literal(1); - const NULL: Expression = Expression::Literal(Scalar::Null(DataType::BOOLEAN)); - const FALSE: Expression = Expression::Literal(Scalar::Boolean(false)); - const TRUE: Expression = Expression::Literal(Scalar::Boolean(true)); + const VAL: Expr = Expr::Literal(Scalar::Integer(1)); + const NULL: Expr = Expr::Literal(Scalar::Null(DataType::BOOLEAN)); + const FALSE: Expr = Expr::Literal(Scalar::Boolean(false)); + const TRUE: Expr = Expr::Literal(Scalar::Boolean(true)); // Basic sanity checks - expect_eq!(AllNullTestFilter.apply_sql_where(val), None, "WHERE {val}"); - expect_eq!(AllNullTestFilter.apply_sql_where(col), None, "WHERE {col}"); + expect_eq!(AllNullTestFilter.eval_sql_where(&VAL), None, "WHERE {VAL}"); + 
expect_eq!(AllNullTestFilter.eval_sql_where(col), None, "WHERE {col}"); expect_eq!( - AllNullTestFilter.apply_sql_where(&Expression::is_null(col.clone())), + AllNullTestFilter.eval_sql_where(&Expr::is_null(col.clone())), Some(true), // No injected NULL checks "WHERE {col} IS NULL" ); expect_eq!( - AllNullTestFilter.apply_sql_where(&Expression::lt(TRUE, FALSE)), + AllNullTestFilter.eval_sql_where(&Expr::lt(TRUE, FALSE)), Some(false), // Injected NULL checks don't short circuit when inputs are NOT NULL "WHERE {TRUE} < {FALSE}" ); // Constrast normal vs SQL WHERE semantics - comparison expect_eq!( - AllNullTestFilter.apply_expr(&Expression::lt(col.clone(), val.clone()), false), + AllNullTestFilter.eval_expr(&Expr::lt(col.clone(), VAL), false), None, - "{col} < {val}" + "{col} < {VAL}" ); expect_eq!( - AllNullTestFilter.apply_sql_where(&Expression::lt(col.clone(), val.clone())), + AllNullTestFilter.eval_sql_where(&Expr::lt(col.clone(), VAL)), Some(false), - "WHERE {col} < {val}" + "WHERE {col} < {VAL}" ); expect_eq!( - AllNullTestFilter.apply_expr(&Expression::lt(val.clone(), col.clone()), false), + AllNullTestFilter.eval_expr(&Expr::lt(VAL, col.clone()), false), None, - "{val} < {col}" + "{VAL} < {col}" ); expect_eq!( - AllNullTestFilter.apply_sql_where(&Expression::lt(val.clone(), col.clone())), + AllNullTestFilter.eval_sql_where(&Expr::lt(VAL, col.clone())), Some(false), - "WHERE {val} < {col}" + "WHERE {VAL} < {col}" ); // Constrast normal vs SQL WHERE semantics - comparison inside AND expect_eq!( - AllNullTestFilter.apply_expr( - &Expression::and(NULL, Expression::lt(col.clone(), val.clone())), + AllNullTestFilter.eval_expr( + &Expr::and(NULL, Expr::lt(col.clone(), VAL)), false ), None, - "{NULL} AND {col} < {val}" + "{NULL} AND {col} < {VAL}" ); expect_eq!( - AllNullTestFilter.apply_sql_where(&Expression::and( + AllNullTestFilter.eval_sql_where(&Expr::and( NULL, - Expression::lt(col.clone(), val.clone()), + Expr::lt(col.clone(), VAL), )), Some(false), - "WHERE 
{NULL} AND {col} < {val}" + "WHERE {NULL} AND {col} < {VAL}" ); expect_eq!( - AllNullTestFilter.apply_expr( - &Expression::and(TRUE, Expression::lt(col.clone(), val.clone())), + AllNullTestFilter.eval_expr( + &Expr::and(TRUE, Expr::lt(col.clone(), VAL)), false ), None, // NULL (from the NULL check) is stronger than TRUE - "{TRUE} AND {col} < {val}" + "{TRUE} AND {col} < {VAL}" ); expect_eq!( - AllNullTestFilter.apply_sql_where(&Expression::and( + AllNullTestFilter.eval_sql_where(&Expr::and( TRUE, - Expression::lt(col.clone(), val.clone()), + Expr::lt(col.clone(), VAL), )), Some(false), // FALSE (from the NULL check) is stronger than TRUE - "WHERE {TRUE} AND {col} < {val}" + "WHERE {TRUE} AND {col} < {VAL}" ); // Contrast normal vs. SQL WHERE semantics - comparison inside AND inside AND expect_eq!( - AllNullTestFilter.apply_expr( - &Expression::and( + AllNullTestFilter.eval_expr( + &Expr::and( TRUE, - Expression::and(NULL, Expression::lt(col.clone(), val.clone())), + Expr::and(NULL, Expr::lt(col.clone(), VAL)), ), false, ), None, - "{TRUE} AND ({NULL} AND {col} < {val})" + "{TRUE} AND ({NULL} AND {col} < {VAL})" ); expect_eq!( - AllNullTestFilter.apply_sql_where(&Expression::and( + AllNullTestFilter.eval_sql_where(&Expr::and( TRUE, - Expression::and(NULL, Expression::lt(col.clone(), val.clone())), + Expr::and(NULL, Expr::lt(col.clone(), VAL)), )), Some(false), - "WHERE {TRUE} AND ({NULL} AND {col} < {val})" + "WHERE {TRUE} AND ({NULL} AND {col} < {VAL})" ); // Semantics are the same for comparison inside OR inside AND expect_eq!( - AllNullTestFilter.apply_expr( - &Expression::or( + AllNullTestFilter.eval_expr( + &Expr::or( FALSE, - Expression::and(NULL, Expression::lt(col.clone(), val.clone())), + Expr::and(NULL, Expr::lt(col.clone(), VAL)), ), false, ), None, - "{FALSE} OR ({NULL} AND {col} < {val})" + "{FALSE} OR ({NULL} AND {col} < {VAL})" ); expect_eq!( - AllNullTestFilter.apply_sql_where(&Expression::or( + AllNullTestFilter.eval_sql_where(&Expr::or( FALSE, - 
Expression::and(NULL, Expression::lt(col.clone(), val.clone())), + Expr::and(NULL, Expr::lt(col.clone(), VAL)), )), None, - "WHERE {FALSE} OR ({NULL} AND {col} < {val})" + "WHERE {FALSE} OR ({NULL} AND {col} < {VAL})" ); } diff --git a/kernel/src/expressions/mod.rs b/kernel/src/expressions/mod.rs index e0a274667..f70e94cf3 100644 --- a/kernel/src/expressions/mod.rs +++ b/kernel/src/expressions/mod.rs @@ -54,25 +54,8 @@ impl BinaryOperator { GreaterThanOrEqual => Some(LessThanOrEqual), LessThan => Some(GreaterThan), LessThanOrEqual => Some(GreaterThanOrEqual), - Equal | NotEqual | Plus | Multiply => Some(*self), - _ => None, - } - } - - /// invert an operator. Returns Some if the operator supports inversion, None if it - /// cannot be inverted - pub(crate) fn invert(&self) -> Option { - use BinaryOperator::*; - match self { - LessThan => Some(GreaterThanOrEqual), - LessThanOrEqual => Some(GreaterThan), - GreaterThan => Some(LessThanOrEqual), - GreaterThanOrEqual => Some(LessThan), - Equal => Some(NotEqual), - NotEqual => Some(Equal), - In => Some(NotIn), - NotIn => Some(In), - _ => None, + Equal | NotEqual | Distinct | Plus | Multiply => Some(*self), + In | NotIn | Minus | Divide => None, // not commutative } } } @@ -197,22 +180,14 @@ impl Display for Expression { UnaryOperator::Not => write!(f, "NOT {expr}"), UnaryOperator::IsNull => write!(f, "{expr} IS NULL"), }, - Self::VariadicOperation { op, exprs } => match op { - VariadicOperator::And => { - write!( - f, - "AND({})", - &exprs.iter().map(|e| format!("{e}")).join(", ") - ) - } - VariadicOperator::Or => { - write!( - f, - "OR({})", - &exprs.iter().map(|e| format!("{e}")).join(", ") - ) - } - }, + Self::VariadicOperation { op, exprs } => { + let exprs = &exprs.iter().map(|e| format!("{e}")).join(", "); + let op = match op { + VariadicOperator::And => "AND", + VariadicOperator::Or => "OR", + }; + write!(f, "{op}({exprs})") + } } } } @@ -356,25 +331,20 @@ impl Expression { } fn walk(&self) -> impl Iterator + '_ { 
+ use Expression::*; let mut stack = vec![self]; std::iter::from_fn(move || { let expr = stack.pop()?; match expr { - Self::Literal(_) => {} - Self::Column { .. } => {} - Self::Struct(exprs) => { - stack.extend(exprs.iter()); - } - Self::BinaryOperation { left, right, .. } => { + Literal(_) => {} + Column { .. } => {} + Struct(exprs) => stack.extend(exprs), + UnaryOperation { expr, .. } => stack.push(expr), + BinaryOperation { left, right, .. } => { stack.push(left); stack.push(right); } - Self::UnaryOperation { expr, .. } => { - stack.push(expr); - } - Self::VariadicOperation { exprs, .. } => { - stack.extend(exprs.iter()); - } + VariadicOperation { exprs, .. } => stack.extend(exprs), } Some(expr) }) diff --git a/kernel/src/lib.rs b/kernel/src/lib.rs index d2a59346c..b792b3789 100644 --- a/kernel/src/lib.rs +++ b/kernel/src/lib.rs @@ -63,6 +63,7 @@ pub mod actions; pub mod engine_data; pub mod error; pub mod expressions; +pub(crate) mod predicates; pub mod table_features; #[cfg(feature = "developer-visibility")] diff --git a/kernel/src/predicates/mod.rs b/kernel/src/predicates/mod.rs new file mode 100644 index 000000000..366daf28f --- /dev/null +++ b/kernel/src/predicates/mod.rs @@ -0,0 +1,643 @@ +use crate::expressions::{ + BinaryOperator, ColumnName, Expression as Expr, Scalar, UnaryOperator, VariadicOperator, +}; +use crate::schema::DataType; + +use std::cmp::Ordering; +use tracing::debug; + +#[cfg(test)] +mod tests; + +/// Evaluates a predicate expression tree against column names that resolve as scalars. Useful for +/// testing/debugging but also serves as a reference implementation that documents the expression +/// semantics that kernel relies on for data skipping. +/// +/// # Inverted expression semantics +/// +/// Because inversion (`NOT` operator) has special semantics and can often be optimized away by +/// pushing it down, most methods take an `inverted` flag. 
That allows operations like +/// [`UnaryOperator::Not`] to simply evaluate their operand with a flipped `inverted` flag, +/// +/// # NULL and error semantics +/// +/// Literal NULL values almost always produce cascading changes in the predicate's structure, so we +/// represent them by `Option::None` rather than `Scalar::Null`. This allows e.g. `A < NULL` to be +/// rewritten as `NULL`, or `AND(NULL, FALSE)` to be rewritten as `FALSE`. +/// +/// Almost all operations produce NULL output if any input is `NULL`. Any resolution failures also +/// produce NULL (such as missing columns or type mismatch between a column and the scalar it is +/// compared against). NULL-checking operations like `IS [NOT] NULL` and `DISTINCT` are special, and +/// rely on nullcount stats for their work (NULL/missing nullcount stats makes them output NULL). +/// +/// For safety reasons, NULL-checking operations only accept literal and column inputs where +/// stats-based skipping is well-defined. If an arbitrary data skipping expression evaluates to +/// NULL, there is no way to tell whether the original expression really evaluated to NULL (safe to +/// use), or the data skipping version evaluated to NULL due to missing stats (very unsafe to use). +/// +/// NOTE: The error-handling semantics of this trait's scalar-based predicate evaluation may differ +/// from those of the engine's expression evaluation, because kernel expressions don't include the +/// necessary type information to reliably detect all type errors. +pub(crate) trait PredicateEvaluator { + type Output; + + /// A (possibly inverted) boolean scalar value, e.g. `[NOT] `. + fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option; + + /// A (possibly inverted) NULL check, e.g. ` IS [NOT] NULL`. + fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option; + + /// A less-than comparison, e.g. ` < `. + /// + /// NOTE: Caller is responsible to commute and/or invert the operation if needed, + /// e.g. 
`NOT( < )` becomes ` <= `. + fn eval_lt(&self, col: &ColumnName, val: &Scalar) -> Option; + + /// A less-than-or-equal comparison, e.g. ` <= ` + /// + /// NOTE: Caller is responsible to commute and/or invert the operation if needed, + /// e.g. `NOT( <= )` becomes ` < `. + fn eval_le(&self, col: &ColumnName, val: &Scalar) -> Option; + + /// A greater-than comparison, e.g. ` > ` + /// + /// NOTE: Caller is responsible to commute and/or invert the operation if needed, + /// e.g. `NOT( > )` becomes ` >= `. + fn eval_gt(&self, col: &ColumnName, val: &Scalar) -> Option; + + /// A greater-than-or-equal comparison, e.g. ` >= ` + /// + /// NOTE: Caller is responsible to commute and/or invert the operation if needed, + /// e.g. `NOT( >= )` becomes ` > `. + fn eval_ge(&self, col: &ColumnName, val: &Scalar) -> Option; + + /// A (possibly inverted) equality comparison, e.g. ` = ` or ` != `. + /// + /// NOTE: Caller is responsible to commute the operation if needed, e.g. ` != ` + /// becomes ` != `. + fn eval_eq(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option; + + /// A (possibly inverted) comparison between two scalars, e.g. ` != `. + fn eval_binary_scalars( + &self, + op: BinaryOperator, + left: &Scalar, + right: &Scalar, + inverted: bool, + ) -> Option; + + /// A (possibly inverted) comparison between two columns, e.g. ` != `. + fn eval_binary_columns( + &self, + op: BinaryOperator, + a: &ColumnName, + b: &ColumnName, + inverted: bool, + ) -> Option; + + /// Completes evaluation of a (possibly inverted) variadic expression. + /// + /// AND and OR are implemented by first evaluating its (possibly inverted) inputs. This part is + /// always the same, provided by [`eval_variadic`]). The results are then combined to become the + /// expression's output in some implementation-defined way (this method). 
+ fn finish_eval_variadic( + &self, + op: VariadicOperator, + exprs: impl IntoIterator>, + inverted: bool, + ) -> Option; + + // ==================== PROVIDED METHODS ==================== + + /// A (possibly inverted) boolean column access, e.g. `[NOT] `. + fn eval_column(&self, col: &ColumnName, inverted: bool) -> Option { + // The expression is equivalent to != FALSE, and the expression NOT is + // equivalent to != TRUE. + self.eval_eq(col, &Scalar::from(inverted), true) + } + + /// Dispatches a (possibly inverted) unary expression to each operator's specific implementation. + fn eval_unary(&self, op: UnaryOperator, expr: &Expr, inverted: bool) -> Option { + match op { + UnaryOperator::Not => self.eval_expr(expr, !inverted), + UnaryOperator::IsNull => { + // Data skipping only supports IS [NOT] NULL over columns (not expressions) + let Expr::Column(col) = expr else { + debug!("Unsupported operand: IS [NOT] NULL: {expr:?}"); + return None; + }; + self.eval_is_null(col, inverted) + } + } + } + + /// A (possibly inverted) DISTINCT test, e.g. `[NOT] DISTINCT(, false)`. DISTINCT can be + /// seen as one of two operations, depending on the input: + /// + /// 1. DISTINCT(, NULL) is equivalent to ` IS NOT NULL` + /// 2. DISTINCT(, ) is equivalent to `OR( IS NULL, != )` + fn eval_distinct( + &self, + col: &ColumnName, + val: &Scalar, + inverted: bool, + ) -> Option { + if let Scalar::Null(_) = val { + self.eval_is_null(col, !inverted) + } else { + let args = [ + self.eval_is_null(col, inverted), + self.eval_eq(col, val, !inverted), + ]; + self.finish_eval_variadic(VariadicOperator::Or, args, inverted) + } + } + + /// A (possibly inverted) IN-list check, e.g. ` [NOT] IN `. + /// + /// Unsupported by default, but implementations can override it if they wish. + fn eval_in(&self, _col: &ColumnName, _val: &Scalar, _inverted: bool) -> Option { + None // TODO? + } + + /// Dispatches a (possibly inverted) binary expression to each operator's specific implementation. 
+ /// + /// NOTE: Only binary operators that produce boolean outputs are supported. + fn eval_binary( + &self, + op: BinaryOperator, + left: &Expr, + right: &Expr, + inverted: bool, + ) -> Option { + use BinaryOperator::*; + use Expr::{Column, Literal}; + + // NOTE: We rely on the literal values to provide logical type hints. That means we cannot + // perform column-column comparisons, because we cannot infer the logical type to use. + let (op, col, val) = match (left, right) { + (Column(a), Column(b)) => return self.eval_binary_columns(op, a, b, inverted), + (Literal(a), Literal(b)) => return self.eval_binary_scalars(op, a, b, inverted), + (Literal(val), Column(col)) => (op.commute()?, col, val), + (Column(col), Literal(val)) => (op, col, val), + _ => { + debug!("Unsupported binary operand(s): {left:?} {op:?} {right:?}"); + return None; + } + }; + match (op, inverted) { + (Plus | Minus | Multiply | Divide, _) => None, // Unsupported - not boolean output + (LessThan, false) | (GreaterThanOrEqual, true) => self.eval_lt(col, val), + (LessThanOrEqual, false) | (GreaterThan, true) => self.eval_le(col, val), + (GreaterThan, false) | (LessThanOrEqual, true) => self.eval_gt(col, val), + (GreaterThanOrEqual, false) | (LessThan, true) => self.eval_ge(col, val), + (Equal, _) => self.eval_eq(col, val, inverted), + (NotEqual, _) => self.eval_eq(col, val, !inverted), + (Distinct, _) => self.eval_distinct(col, val, inverted), + (In, _) => self.eval_in(col, val, inverted), + (NotIn, _) => self.eval_in(col, val, !inverted), + } + } + + /// Dispatches a variadic operation, leveraging each implementation's [`finish_eval_variadic`]. + fn eval_variadic( + &self, + op: VariadicOperator, + exprs: &[Expr], + inverted: bool, + ) -> Option { + let exprs = exprs.iter().map(|expr| self.eval_expr(expr, inverted)); + self.finish_eval_variadic(op, exprs, inverted) + } + + /// Dispatches an expression to the specific implementation for each expression variant. 
+ /// + /// NOTE: [`Expression::Struct`] is not supported and always evaluates to `None`. + fn eval_expr(&self, expr: &Expr, inverted: bool) -> Option { + use Expr::*; + match expr { + Literal(val) => self.eval_scalar(val, inverted), + Column(col) => self.eval_column(col, inverted), + Struct(_) => None, // not supported + UnaryOperation { op, expr } => self.eval_unary(*op, expr, inverted), + BinaryOperation { op, left, right } => self.eval_binary(*op, left, right, inverted), + VariadicOperation { op, exprs } => self.eval_variadic(*op, exprs, inverted), + } + } +} + +/// A collection of provided methods from the [`PredicateEvaluator`] trait, factored out to allow +/// reuse by the different predicate evaluator implementations. +pub(crate) struct PredicateEvaluatorDefaults; +impl PredicateEvaluatorDefaults { + /// Directly evaluates a boolean scalar. See [`PredicateEvaluator::eval_scalar`]. + pub(crate) fn eval_scalar(val: &Scalar, inverted: bool) -> Option { + match val { + Scalar::Boolean(val) => Some(*val != inverted), + _ => None, + } + } + + /// A (possibly inverted) partial comparison of two scalars, leveraging the [`PartialOrd`] + /// trait. + pub(crate) fn partial_cmp_scalars( + ord: Ordering, + a: &Scalar, + b: &Scalar, + inverted: bool, + ) -> Option { + let cmp = a.partial_cmp(b)?; + let matched = cmp == ord; + Some(matched != inverted) + } + + /// Directly evaluates a boolean comparison. See [`PredicateEvaluator::eval_binary_scalars`]. 
+ pub(crate) fn eval_binary_scalars( + op: BinaryOperator, + left: &Scalar, + right: &Scalar, + inverted: bool, + ) -> Option { + use BinaryOperator::*; + match op { + Equal => Self::partial_cmp_scalars(Ordering::Equal, left, right, inverted), + NotEqual => Self::partial_cmp_scalars(Ordering::Equal, left, right, !inverted), + LessThan => Self::partial_cmp_scalars(Ordering::Less, left, right, inverted), + LessThanOrEqual => Self::partial_cmp_scalars(Ordering::Greater, left, right, !inverted), + GreaterThan => Self::partial_cmp_scalars(Ordering::Greater, left, right, inverted), + GreaterThanOrEqual => Self::partial_cmp_scalars(Ordering::Less, left, right, !inverted), + _ => { + debug!("Unsupported binary operator: {left:?} {op:?} {right:?}"); + None + } + } + } + + /// Finishes evaluating a (possibly inverted) variadic operation. See + /// [`PredicateEvaluator::finish_eval_variadic`]. + /// + /// The inputs were already inverted by the caller, if needed. + /// + /// With AND (OR), any FALSE (TRUE) input dominates, forcing a FALSE (TRUE) output. If there + /// was no dominating input, then any NULL input forces NULL output. Otherwise, return the + /// non-dominant value. Inverting the operation also inverts the dominant value. 
+ pub(crate) fn finish_eval_variadic( + op: VariadicOperator, + exprs: impl IntoIterator>, + inverted: bool, + ) -> Option { + let dominator = match op { + VariadicOperator::And => inverted, + VariadicOperator::Or => !inverted, + }; + let result = exprs.into_iter().try_fold(false, |found_null, val| { + match val { + Some(val) if val == dominator => None, // (1) short circuit, dominant found + Some(_) => Some(found_null), + None => Some(true), // (2) null found (but keep looking for a dominant value) + } + }); + + match result { + None => Some(dominator), // (1) short circuit, dominant found + Some(false) => Some(!dominator), + Some(true) => None, // (2) null found, dominant not found + } + } +} + +/// Resolves columns as scalars, as a building block for [`DefaultPredicateEvaluator`]. +pub(crate) trait ResolveColumnAsScalar { + fn resolve_column(&self, col: &ColumnName) -> Option; +} + +// Some tests do not actually require column resolution +#[cfg(test)] +pub(crate) struct UnimplementedColumnResolver; +#[cfg(test)] +impl ResolveColumnAsScalar for UnimplementedColumnResolver { + fn resolve_column(&self, _col: &ColumnName) -> Option { + unimplemented!() + } +} + +// In testing, it is convenient to just build a hashmap of scalar values. +#[cfg(test)] +impl ResolveColumnAsScalar for std::collections::HashMap { + fn resolve_column(&self, col: &ColumnName) -> Option { + self.get(col).cloned() + } +} + +/// A predicate evaluator that directly evaluates the predicate to produce an `Option` +/// result. Column resolution is handled by an embedded [`ResolveColumnAsScalar`] instance. +pub(crate) struct DefaultPredicateEvaluator { + resolver: R, +} +impl DefaultPredicateEvaluator { + // Convenient thin wrapper + fn resolve_column(&self, col: &ColumnName) -> Option { + self.resolver.resolve_column(col) + } +} + +impl From for DefaultPredicateEvaluator { + fn from(resolver: R) -> Self { + Self { resolver } + } +} + +/// A "normal" predicate evaluator. 
It takes expressions as input, uses a [`ResolveColumnAsScalar`] +/// to convert column references to scalars, and evaluates the resulting constant expression to +/// produce a boolean output. +impl PredicateEvaluator for DefaultPredicateEvaluator { + type Output = bool; + + fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option { + PredicateEvaluatorDefaults::eval_scalar(val, inverted) + } + + fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option { + let col = self.resolve_column(col)?; + Some(matches!(col, Scalar::Null(_)) != inverted) + } + + fn eval_lt(&self, col: &ColumnName, val: &Scalar) -> Option { + let col = self.resolve_column(col)?; + self.eval_binary_scalars(BinaryOperator::LessThan, &col, val, false) + } + + fn eval_le(&self, col: &ColumnName, val: &Scalar) -> Option { + let col = self.resolve_column(col)?; + self.eval_binary_scalars(BinaryOperator::LessThanOrEqual, &col, val, false) + } + + fn eval_gt(&self, col: &ColumnName, val: &Scalar) -> Option { + let col = self.resolve_column(col)?; + self.eval_binary_scalars(BinaryOperator::GreaterThan, &col, val, false) + } + + fn eval_ge(&self, col: &ColumnName, val: &Scalar) -> Option { + let col = self.resolve_column(col)?; + self.eval_binary_scalars(BinaryOperator::GreaterThanOrEqual, &col, val, false) + } + + fn eval_eq(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { + let col = self.resolve_column(col)?; + self.eval_binary_scalars(BinaryOperator::Equal, &col, val, inverted) + } + + fn eval_binary_scalars( + &self, + op: BinaryOperator, + left: &Scalar, + right: &Scalar, + inverted: bool, + ) -> Option { + PredicateEvaluatorDefaults::eval_binary_scalars(op, left, right, inverted) + } + + fn eval_binary_columns( + &self, + op: BinaryOperator, + left: &ColumnName, + right: &ColumnName, + inverted: bool, + ) -> Option { + let left = self.resolve_column(left)?; + let right = self.resolve_column(right)?; + self.eval_binary_scalars(op, &left, &right, inverted) + } + + fn 
finish_eval_variadic( + &self, + op: VariadicOperator, + exprs: impl IntoIterator>, + inverted: bool, + ) -> Option { + PredicateEvaluatorDefaults::finish_eval_variadic(op, exprs, inverted) + } +} + +/// A predicate evaluator that implements data skipping semantics over various column stats. For +/// example, comparisons involving a column are converted into comparisons over that column's +/// min/max stats, and NULL checks are converted into comparisons involving the column's nullcount +/// and rowcount stats. +/// +/// The types involved in these operations are parameterized and implementation-specific. For +/// example, [`crate::engine::parquet_stats_skipping::ParquetStatsProvider`] directly evaluates data +/// skipping expressions and returns boolean results, while +/// [`crate::scan::data_skipping::DataSkippingPredicateCreator`] instead converts the input +/// predicate to a data skipping predicate that can be evaluated directly later. +pub(crate) trait DataSkippingPredicateEvaluator { + /// The output type produced by this expression evaluator + type Output; + /// The type of min and max column stats + type TypedStat; + /// The type of nullcount and rowcount column stats + type IntStat; + + /// Retrieves the minimum value of a column, if it exists and has the requested type. + fn get_min_stat(&self, col: &ColumnName, data_type: &DataType) -> Option; + + /// Retrieves the maximum value of a column, if it exists and has the requested type. + fn get_max_stat(&self, col: &ColumnName, data_type: &DataType) -> Option; + + /// Retrieves the null count of a column, if it exists. + fn get_nullcount_stat(&self, col: &ColumnName) -> Option; + + /// Retrieves the row count of a column (parquet footers always include this stat). + fn get_rowcount_stat(&self) -> Option; + + /// See [`PredicateEvaluator::eval_scalar`] + fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option; + + /// For IS NULL (IS NOT NULL), we can only skip the file if all-null (no-null). 
Any other + /// nullcount always forces us to keep the file. + /// + /// NOTE: When deletion vectors are enabled, they could produce a file that is logically + /// all-null or logically no-null, even though the physical stats indicate a mix of null and + /// non-null values. They cannot invalidate a file's physical all-null or non-null status, + /// however, so the worst that can happen is we fail to skip an unnecessary file. + fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option; + + /// See [`PredicateEvaluator::eval_binary_scalars`] + fn eval_binary_scalars( + &self, + op: BinaryOperator, + left: &Scalar, + right: &Scalar, + inverted: bool, + ) -> Option; + + /// See [`PredicateEvaluator::finish_eval_variadic`] + fn finish_eval_variadic( + &self, + op: VariadicOperator, + exprs: impl IntoIterator>, + inverted: bool, + ) -> Option; + + /// Helper method that performs a (possibly inverted) partial comparison between a typed column + /// stat and a scalar. + fn eval_partial_cmp( + &self, + ord: Ordering, + col: Self::TypedStat, + val: &Scalar, + inverted: bool, + ) -> Option; + + /// Performs a partial comparison against a column min-stat. See + /// [`PredicateEvaluatorDefaults::partial_cmp_scalars`] for details of the comparison semantics. + fn partial_cmp_min_stat( + &self, + col: &ColumnName, + val: &Scalar, + ord: Ordering, + inverted: bool, + ) -> Option { + let min = self.get_min_stat(col, &val.data_type())?; + self.eval_partial_cmp(ord, min, val, inverted) + } + + /// Performs a partial comparison against a column max-stat. See + /// [`PredicateEvaluatorDefaults::partial_cmp_scalars`] for details of the comparison semantics. 
+ fn partial_cmp_max_stat( + &self, + col: &ColumnName, + val: &Scalar, + ord: Ordering, + inverted: bool, + ) -> Option { + let max = self.get_max_stat(col, &val.data_type())?; + self.eval_partial_cmp(ord, max, val, inverted) + } + + /// See [`PredicateEvaluator::eval_lt`] + fn eval_lt(&self, col: &ColumnName, val: &Scalar) -> Option { + // Given `col < val`: + // Skip if `val` is not greater than _all_ values in [min, max], implies + // Skip if `val <= min AND val <= max` implies + // Skip if `val <= min` implies + // Keep if `NOT(val <= min)` implies + // Keep if `val > min` implies + // Keep if `min < val` + self.partial_cmp_min_stat(col, val, Ordering::Less, false) + } + + /// See [`PredicateEvaluator::eval_le`] + fn eval_le(&self, col: &ColumnName, val: &Scalar) -> Option { + // Given `col <= val`: + // Skip if `val` is less than _all_ values in [min, max], implies + // Skip if `val < min AND val < max` implies + // Skip if `val < min` implies + // Keep if `NOT(val < min)` implies + // Keep if `NOT(min > val)` + self.partial_cmp_min_stat(col, val, Ordering::Greater, true) + } + + /// See [`PredicateEvaluator::eval_gt`] + fn eval_gt(&self, col: &ColumnName, val: &Scalar) -> Option { + // Given `col > val`: + // Skip if `val` is not less than _all_ values in [min, max], implies + // Skip if `val >= min AND val >= max` implies + // Skip if `val >= max` implies + // Keep if `NOT(val >= max)` implies + // Keep if `NOT(max <= val)` implies + // Keep if `max > val` + self.partial_cmp_max_stat(col, val, Ordering::Greater, false) + } + + /// See [`PredicateEvaluator::eval_ge`] + fn eval_ge(&self, col: &ColumnName, val: &Scalar) -> Option { + // Given `col >= val`: + // Skip if `val is greater than _every_ value in [min, max], implies + // Skip if `val > min AND val > max` implies + // Skip if `val > max` implies + // Keep if `NOT(val > max)` implies + // Keep if `NOT(max < val)` + self.partial_cmp_max_stat(col, val, Ordering::Less, true) + } + + /// See 
[`PredicateEvaluator::eval_eq`] + fn eval_eq(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { + let (op, exprs) = if inverted { + // Column could compare not-equal if min or max value differs from the literal. + let exprs = [ + self.partial_cmp_min_stat(col, val, Ordering::Equal, true), + self.partial_cmp_max_stat(col, val, Ordering::Equal, true), + ]; + (VariadicOperator::Or, exprs) + } else { + // Column could compare equal if its min/max values bracket the literal. + let exprs = [ + self.partial_cmp_min_stat(col, val, Ordering::Greater, true), + self.partial_cmp_max_stat(col, val, Ordering::Less, true), + ]; + (VariadicOperator::And, exprs) + }; + self.finish_eval_variadic(op, exprs, false) + } +} + +impl PredicateEvaluator for T { + type Output = T::Output; + + fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option { + self.eval_scalar(val, inverted) + } + + fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option { + self.eval_is_null(col, inverted) + } + + fn eval_lt(&self, col: &ColumnName, val: &Scalar) -> Option { + self.eval_lt(col, val) + } + + fn eval_le(&self, col: &ColumnName, val: &Scalar) -> Option { + self.eval_le(col, val) + } + + fn eval_gt(&self, col: &ColumnName, val: &Scalar) -> Option { + self.eval_gt(col, val) + } + + fn eval_ge(&self, col: &ColumnName, val: &Scalar) -> Option { + self.eval_ge(col, val) + } + + fn eval_eq(&self, col: &ColumnName, val: &Scalar, inverted: bool) -> Option { + self.eval_eq(col, val, inverted) + } + + fn eval_binary_scalars( + &self, + op: BinaryOperator, + left: &Scalar, + right: &Scalar, + inverted: bool, + ) -> Option { + self.eval_binary_scalars(op, left, right, inverted) + } + + fn eval_binary_columns( + &self, + _op: BinaryOperator, + _a: &ColumnName, + _b: &ColumnName, + _inverted: bool, + ) -> Option { + None // Unsupported + } + + fn finish_eval_variadic( + &self, + op: VariadicOperator, + exprs: impl IntoIterator>, + inverted: bool, + ) -> Option { + 
self.finish_eval_variadic(op, exprs, inverted) + } +} diff --git a/kernel/src/predicates/tests.rs b/kernel/src/predicates/tests.rs new file mode 100644 index 000000000..ce273e7b8 --- /dev/null +++ b/kernel/src/predicates/tests.rs @@ -0,0 +1,572 @@ +use super::*; +use crate::expressions::{ + column_expr, column_name, ArrayData, Expression, StructData, UnaryOperator, +}; +use crate::predicates::PredicateEvaluator; +use crate::schema::ArrayType; +use crate::DataType; + +use std::collections::HashMap; + +macro_rules! expect_eq { + ( $expr: expr, $expect: expr, $fmt: literal ) => { + let expect = ($expect); + let result = ($expr); + assert!( + result == expect, + "Expected {} = {:?}, got {:?}", + format!($fmt), + expect, + result + ); + }; +} + +impl ResolveColumnAsScalar for Scalar { + fn resolve_column(&self, _col: &ColumnName) -> Option { + Some(self.clone()) + } +} + +#[test] +fn test_default_eval_scalar() { + let test_cases = [ + (Scalar::Boolean(true), false, Some(true)), + (Scalar::Boolean(true), true, Some(false)), + (Scalar::Boolean(false), false, Some(false)), + (Scalar::Boolean(false), true, Some(true)), + (Scalar::Long(1), false, None), + (Scalar::Long(1), true, None), + (Scalar::Null(DataType::BOOLEAN), false, None), + (Scalar::Null(DataType::BOOLEAN), true, None), + (Scalar::Null(DataType::LONG), false, None), + (Scalar::Null(DataType::LONG), true, None), + ]; + for (value, inverted, expect) in test_cases.into_iter() { + assert_eq!( + PredicateEvaluatorDefaults::eval_scalar(&value, inverted), + expect, + "value: {value:?} inverted: {inverted}" + ); + } +} + +// verifies that partial orderings behave as expected for all Scalar types +#[test] +fn test_default_partial_cmp_scalars() { + use Ordering::*; + use Scalar::*; + + let smaller_values = &[ + Integer(1), + Long(1), + Short(1), + Byte(1), + Float(1.0), + Double(1.0), + String("1".into()), + Boolean(false), + Timestamp(1), + TimestampNtz(1), + Date(1), + Binary(vec![1]), + Decimal(1, 10, 10), // invalid 
value, + Null(DataType::LONG), + Struct(StructData::try_new(vec![], vec![]).unwrap()), + Array(ArrayData::new( + ArrayType::new(DataType::LONG, false), + &[] as &[i64], + )), + ]; + let larger_values = &[ + Integer(10), + Long(10), + Short(10), + Byte(10), + Float(10.0), + Double(10.0), + String("10".into()), + Boolean(true), + Timestamp(10), + TimestampNtz(10), + Date(10), + Binary(vec![10]), + Decimal(10, 10, 10), // invalid value + Null(DataType::LONG), + Struct(StructData::try_new(vec![], vec![]).unwrap()), + Array(ArrayData::new( + ArrayType::new(DataType::LONG, false), + &[] as &[i64], + )), + ]; + + // scalars of different types are always incomparable + let compare = PredicateEvaluatorDefaults::partial_cmp_scalars; + for (i, a) in smaller_values.iter().enumerate() { + for b in smaller_values.iter().skip(i + 1) { + for op in [Less, Equal, Greater] { + for inverted in [true, false] { + assert!( + compare(op, a, b, inverted).is_none(), + "{:?} should not be comparable to {:?}", + a.data_type(), + b.data_type() + ); + } + } + } + } + + let expect_if_comparable_type = |s: &_, expect| match s { + Null(_) | Decimal(..) 
| Struct(_) | Array(_) => None, + _ => Some(expect), + }; + + // Test same-type comparisons where a == b + for (a, b) in smaller_values.iter().zip(smaller_values) { + for inverted in [true, false] { + expect_eq!( + compare(Less, a, b, inverted), + expect_if_comparable_type(a, inverted), + "{a:?} < {b:?} (inverted: {inverted})" + ); + + expect_eq!( + compare(Equal, a, b, inverted), + expect_if_comparable_type(a, !inverted), + "{a:?} == {b:?} (inverted: {inverted})" + ); + + expect_eq!( + compare(Greater, a, b, inverted), + expect_if_comparable_type(a, inverted), + "{a:?} > {b:?} (inverted: {inverted})" + ); + } + } + + // Test same-type comparisons where a < b + for (a, b) in smaller_values.iter().zip(larger_values) { + for inverted in [true, false] { + expect_eq!( + compare(Less, a, b, inverted), + expect_if_comparable_type(a, !inverted), + "{a:?} < {b:?} (inverted: {inverted})" + ); + + expect_eq!( + compare(Equal, a, b, inverted), + expect_if_comparable_type(a, inverted), + "{a:?} == {b:?} (inverted: {inverted})" + ); + + expect_eq!( + compare(Greater, a, b, inverted), + expect_if_comparable_type(a, inverted), + "{a:?} < {b:?} (inverted: {inverted})" + ); + + expect_eq!( + compare(Less, b, a, inverted), + expect_if_comparable_type(a, inverted), + "{b:?} < {a:?} (inverted: {inverted})" + ); + + expect_eq!( + compare(Equal, b, a, inverted), + expect_if_comparable_type(a, inverted), + "{b:?} == {a:?} (inverted: {inverted})" + ); + + expect_eq!( + compare(Greater, b, a, inverted), + expect_if_comparable_type(a, !inverted), + "{b:?} < {a:?} (inverted: {inverted})" + ); + } + } +} + +// Verifies that eval_binary_scalars uses partial_cmp_scalars correctly +#[test] +fn test_eval_binary_scalars() { + use BinaryOperator::*; + let smaller_value = Scalar::Long(1); + let larger_value = Scalar::Long(10); + for inverted in [true, false] { + let compare = PredicateEvaluatorDefaults::eval_binary_scalars; + expect_eq!( + compare(Equal, &smaller_value, &smaller_value, inverted), + 
Some(!inverted), + "{smaller_value} == {smaller_value} (inverted: {inverted})" + ); + expect_eq!( + compare(Equal, &smaller_value, &larger_value, inverted), + Some(inverted), + "{smaller_value} == {larger_value} (inverted: {inverted})" + ); + + expect_eq!( + compare(NotEqual, &smaller_value, &smaller_value, inverted), + Some(inverted), + "{smaller_value} != {smaller_value} (inverted: {inverted})" + ); + expect_eq!( + compare(NotEqual, &smaller_value, &larger_value, inverted), + Some(!inverted), + "{smaller_value} != {larger_value} (inverted: {inverted})" + ); + + expect_eq!( + compare(LessThan, &smaller_value, &smaller_value, inverted), + Some(inverted), + "{smaller_value} < {smaller_value} (inverted: {inverted})" + ); + expect_eq!( + compare(LessThan, &smaller_value, &larger_value, inverted), + Some(!inverted), + "{smaller_value} < {larger_value} (inverted: {inverted})" + ); + + expect_eq!( + compare(GreaterThan, &smaller_value, &smaller_value, inverted), + Some(inverted), + "{smaller_value} > {smaller_value} (inverted: {inverted})" + ); + expect_eq!( + compare(GreaterThan, &smaller_value, &larger_value, inverted), + Some(inverted), + "{smaller_value} > {larger_value} (inverted: {inverted})" + ); + + expect_eq!( + compare(LessThanOrEqual, &smaller_value, &smaller_value, inverted), + Some(!inverted), + "{smaller_value} <= {smaller_value} (inverted: {inverted})" + ); + expect_eq!( + compare(LessThanOrEqual, &smaller_value, &larger_value, inverted), + Some(!inverted), + "{smaller_value} <= {larger_value} (inverted: {inverted})" + ); + + expect_eq!( + compare(GreaterThanOrEqual, &smaller_value, &smaller_value, inverted), + Some(!inverted), + "{smaller_value} >= {smaller_value} (inverted: {inverted})" + ); + expect_eq!( + compare(GreaterThanOrEqual, &smaller_value, &larger_value, inverted), + Some(inverted), + "{smaller_value} >= {larger_value} (inverted: {inverted})" + ); + } +} + +// NOTE: We're testing routing here -- the actual comparisons are already validated by 
test_eval_binary_scalars. +#[test] +fn test_eval_binary_columns() { + let columns = HashMap::from_iter(vec![ + (column_name!("x"), Scalar::from(1)), + (column_name!("y"), Scalar::from(10)), + ]); + let filter = DefaultPredicateEvaluator::from(columns); + let x = column_expr!("x"); + let y = column_expr!("y"); + for inverted in [true, false] { + assert_eq!( + filter.eval_binary(BinaryOperator::Equal, &x, &y, inverted), + Some(inverted), + "x = y (inverted: {inverted})" + ); + assert_eq!( + filter.eval_binary(BinaryOperator::Equal, &x, &x, inverted), + Some(!inverted), + "x = x (inverted: {inverted})" + ); + } +} + +#[test] +fn test_eval_variadic() { + let test_cases: Vec<(&[_], _, _)> = vec![ + // input, AND expect, OR expect + (&[], Some(true), Some(false)), + (&[Some(true)], Some(true), Some(true)), + (&[Some(false)], Some(false), Some(false)), + (&[None], None, None), + (&[Some(true), Some(false)], Some(false), Some(true)), + (&[Some(false), Some(true)], Some(false), Some(true)), + (&[Some(true), None], None, Some(true)), + (&[None, Some(true)], None, Some(true)), + (&[Some(false), None], Some(false), None), + (&[None, Some(false)], Some(false), None), + (&[None, Some(false), Some(true)], Some(false), Some(true)), + (&[None, Some(true), Some(false)], Some(false), Some(true)), + (&[Some(false), None, Some(true)], Some(false), Some(true)), + (&[Some(true), None, Some(false)], Some(false), Some(true)), + (&[Some(false), Some(true), None], Some(false), Some(true)), + (&[Some(true), Some(false), None], Some(false), Some(true)), + ]; + let filter = DefaultPredicateEvaluator::from(UnimplementedColumnResolver); + for (inputs, expect_and, expect_or) in test_cases.iter() { + let inputs: Vec<_> = inputs + .iter() + .cloned() + .map(|v| match v { + Some(v) => Expression::literal(v), + None => Expression::null_literal(DataType::BOOLEAN), + }) + .collect(); + for inverted in [true, false] { + let invert_if_needed = |v: &Option<_>| v.map(|v| v != inverted); + expect_eq!( + 
filter.eval_variadic(VariadicOperator::And, &inputs, inverted), + invert_if_needed(expect_and), + "AND({inputs:?}) (inverted: {inverted})" + ); + expect_eq!( + filter.eval_variadic(VariadicOperator::Or, &inputs, inverted), + invert_if_needed(expect_or), + "OR({inputs:?}) (inverted: {inverted})" + ); + } + } +} + +#[test] +fn test_eval_column() { + let test_cases = [ + (Scalar::from(true), Some(true)), + (Scalar::from(false), Some(false)), + (Scalar::Null(DataType::BOOLEAN), None), + (Scalar::from(1), None), + ]; + let col = &column_name!("x"); + for (input, expect) in &test_cases { + let filter = DefaultPredicateEvaluator::from(input.clone()); + for inverted in [true, false] { + expect_eq!( + filter.eval_column(col, inverted), + expect.map(|v| v != inverted), + "{input:?} (inverted: {inverted})" + ); + } + } +} + +#[test] +fn test_eval_not() { + let test_cases = [ + (Scalar::Boolean(true), Some(false)), + (Scalar::Boolean(false), Some(true)), + (Scalar::Null(DataType::BOOLEAN), None), + (Scalar::Long(1), None), + ]; + let filter = DefaultPredicateEvaluator::from(UnimplementedColumnResolver); + for (input, expect) in test_cases { + let input = input.into(); + for inverted in [true, false] { + expect_eq!( + filter.eval_unary(UnaryOperator::Not, &input, inverted), + expect.map(|v| v != inverted), + "NOT({input:?}) (inverted: {inverted})" + ); + } + } +} + +#[test] +fn test_eval_is_null() { + let expr = column_expr!("x"); + let filter = DefaultPredicateEvaluator::from(Scalar::from(1)); + expect_eq!( + filter.eval_unary(UnaryOperator::IsNull, &expr, true), + Some(true), + "x IS NOT NULL" + ); + expect_eq!( + filter.eval_unary(UnaryOperator::IsNull, &expr, false), + Some(false), + "x IS NULL" + ); + + let expr = Expression::literal(1); + expect_eq!( + filter.eval_unary(UnaryOperator::IsNull, &expr, true), + None, + "1 IS NOT NULL" + ); + expect_eq!( + filter.eval_unary(UnaryOperator::IsNull, &expr, false), + None, + "1 IS NULL" + ); +} + +#[test] +fn test_eval_distinct() 
{ + let one = &Scalar::from(1); + let two = &Scalar::from(2); + let null = &Scalar::Null(DataType::INTEGER); + let filter = DefaultPredicateEvaluator::from(one.clone()); + let col = &column_name!("x"); + expect_eq!( + filter.eval_distinct(col, one, true), + Some(true), + "NOT DISTINCT(x, 1) (x = 1)" + ); + expect_eq!( + filter.eval_distinct(col, one, false), + Some(false), + "DISTINCT(x, 1) (x = 1)" + ); + expect_eq!( + filter.eval_distinct(col, two, true), + Some(false), + "NOT DISTINCT(x, 2) (x = 1)" + ); + expect_eq!( + filter.eval_distinct(col, two, false), + Some(true), + "DISTINCT(x, 2) (x = 1)" + ); + expect_eq!( + filter.eval_distinct(col, null, true), + Some(false), + "NOT DISTINCT(x, NULL) (x = 1)" + ); + expect_eq!( + filter.eval_distinct(col, null, false), + Some(true), + "DISTINCT(x, NULL) (x = 1)" + ); + + let filter = DefaultPredicateEvaluator::from(null.clone()); + expect_eq!( + filter.eval_distinct(col, one, true), + Some(false), + "NOT DISTINCT(x, 1) (x = NULL)" + ); + expect_eq!( + filter.eval_distinct(col, one, false), + Some(true), + "DISTINCT(x, 1) (x = NULL)" + ); + expect_eq!( + filter.eval_distinct(col, null, true), + Some(true), + "NOT DISTINCT(x, NULL) (x = NULL)" + ); + expect_eq!( + filter.eval_distinct(col, null, false), + Some(false), + "DISTINCT(x, NULL) (x = NULL)" + ); +} + +// NOTE: We're testing routing here -- the actual comparisons are already validated by +// test_eval_binary_scalars. 
+#[test] +fn eval_binary() { + let col = column_expr!("x"); + let val = Expression::literal(10); + let filter = DefaultPredicateEvaluator::from(Scalar::from(1)); + + // unsupported + expect_eq!( + filter.eval_binary(BinaryOperator::Plus, &col, &val, false), + None, + "x + 10" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::Minus, &col, &val, false), + None, + "x - 10" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::Multiply, &col, &val, false), + None, + "x * 10" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::Divide, &col, &val, false), + None, + "x / 10" + ); + + // supported + for inverted in [true, false] { + expect_eq!( + filter.eval_binary(BinaryOperator::LessThan, &col, &val, inverted), + Some(!inverted), + "x < 10 (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::LessThanOrEqual, &col, &val, inverted), + Some(!inverted), + "x <= 10 (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::Equal, &col, &val, inverted), + Some(inverted), + "x = 10 (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::NotEqual, &col, &val, inverted), + Some(!inverted), + "x != 10 (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::GreaterThanOrEqual, &col, &val, inverted), + Some(inverted), + "x >= 10 (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::GreaterThan, &col, &val, inverted), + Some(inverted), + "x > 10 (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::Distinct, &col, &val, inverted), + Some(!inverted), + "DISTINCT(x, 10) (inverted: {inverted})" + ); + + expect_eq!( + filter.eval_binary(BinaryOperator::LessThan, &val, &col, inverted), + Some(inverted), + "10 < x (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::LessThanOrEqual, &val, &col, inverted), + Some(inverted), + "10 <= x (inverted: {inverted})" + ); + expect_eq!( + 
filter.eval_binary(BinaryOperator::Equal, &val, &col, inverted), + Some(inverted), + "10 = x (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::NotEqual, &val, &col, inverted), + Some(!inverted), + "10 != x (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::GreaterThanOrEqual, &val, &col, inverted), + Some(!inverted), + "10 >= x (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::GreaterThan, &val, &col, inverted), + Some(!inverted), + "10 > x (inverted: {inverted})" + ); + expect_eq!( + filter.eval_binary(BinaryOperator::Distinct, &val, &col, inverted), + Some(!inverted), + "DISTINCT(10, x) (inverted: {inverted})" + ); + } +} diff --git a/kernel/src/scan/data_skipping.rs b/kernel/src/scan/data_skipping.rs index 5a35d2ddd..53845bcfe 100644 --- a/kernel/src/scan/data_skipping.rs +++ b/kernel/src/scan/data_skipping.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::cmp::Ordering; use std::collections::HashSet; use std::sync::{Arc, LazyLock}; @@ -8,86 +9,17 @@ use crate::actions::get_log_add_schema; use crate::actions::visitors::SelectionVectorVisitor; use crate::error::DeltaResult; use crate::expressions::{ - column_expr, column_name, joined_column_expr, BinaryOperator, Expression as Expr, - ExpressionRef, UnaryOperator, VariadicOperator, + column_expr, joined_column_expr, BinaryOperator, ColumnName, Expression as Expr, ExpressionRef, + Scalar, VariadicOperator, +}; +use crate::predicates::{ + DataSkippingPredicateEvaluator, PredicateEvaluator, PredicateEvaluatorDefaults, }; use crate::schema::{DataType, PrimitiveType, SchemaRef, SchemaTransform, StructField, StructType}; use crate::{Engine, EngineData, ExpressionEvaluator, JsonHandler}; -/// Get the expression that checks if a col could be null, assuming tight_bounds = true. In this -/// case a column can contain null if any value > 0 is in the nullCount. 
This is further complicated -/// by the default for tightBounds being true, so we have to check if it's EITHER `null` OR `true` -fn get_tight_null_expr(null_col: Expr) -> Expr { - Expr::and( - Expr::distinct(column_expr!("tightBounds"), false), - Expr::gt(null_col, 0i64), - ) -} - -/// Get the expression that checks if a col could be null, assuming tight_bounds = false. In this -/// case, we can only check if the WHOLE column is null, by checking if the number of records is -/// equal to the null count, since all other values of nullCount must be ignored (except 0, which -/// doesn't help us) -fn get_wide_null_expr(null_col: Expr) -> Expr { - Expr::and( - Expr::eq(column_expr!("tightBounds"), false), - Expr::eq(column_expr!("numRecords"), null_col), - ) -} - -/// Get the expression that checks if a col could NOT be null, assuming tight_bounds = true. In this -/// case a column has a NOT NULL record if nullCount < numRecords. This is further complicated by -/// the default for tightBounds being true, so we have to check if it's EITHER `null` OR `true` -fn get_tight_not_null_expr(null_col: Expr) -> Expr { - Expr::and( - Expr::distinct(column_expr!("tightBounds"), false), - Expr::lt(null_col, column_expr!("numRecords")), - ) -} - -/// Get the expression that checks if a col could NOT be null, assuming tight_bounds = false. In -/// this case, we can only check if the WHOLE column null, by checking if the nullCount == -/// numRecords. 
So by inverting that check and seeing if nullCount != numRecords, we can check if -/// there is a possibility of a NOT null -fn get_wide_not_null_expr(null_col: Expr) -> Expr { - Expr::and( - Expr::eq(column_expr!("tightBounds"), false), - Expr::ne(column_expr!("numRecords"), null_col), - ) -} - -/// Use De Morgan's Laws to push a NOT expression down the tree -fn as_inverted_data_skipping_predicate(expr: &Expr) -> Option { - use Expr::*; - match expr { - UnaryOperation { op, expr } => match op { - UnaryOperator::Not => as_data_skipping_predicate(expr), - UnaryOperator::IsNull => { - // to check if a column could NOT have a null, we need two different checks, to see - // if the bounds are tight and then to actually do the check - if let Column(col) = expr.as_ref() { - let null_col = joined_column_expr!("nullCount", col); - Some(Expr::or( - get_tight_not_null_expr(null_col.clone()), - get_wide_not_null_expr(null_col), - )) - } else { - // can't check anything other than a col for null - None - } - } - }, - BinaryOperation { op, left, right } => { - let expr = Expr::binary(op.invert()?, left.as_ref().clone(), right.as_ref().clone()); - as_data_skipping_predicate(&expr) - } - VariadicOperation { op, exprs } => { - let expr = Expr::variadic(op.invert(), exprs.iter().cloned().map(|e| !e)); - as_data_skipping_predicate(&expr) - } - _ => None, - } -} +#[cfg(test)] +mod tests; /// Rewrites a predicate to a predicate that can be used to skip files based on their stats. /// Returns `None` if the predicate is not eligible for data skipping. @@ -105,62 +37,8 @@ fn as_inverted_data_skipping_predicate(expr: &Expr) -> Option { /// are not eligible for data skipping. /// - `OR` is rewritten only if all operands are eligible for data skipping. Otherwise, the whole OR /// expression is dropped. 
-fn as_data_skipping_predicate(expr: &Expr) -> Option { - use BinaryOperator::*; - use Expr::*; - use UnaryOperator::*; - - match expr { - BinaryOperation { op, left, right } => { - let (op, col, val) = match (left.as_ref(), right.as_ref()) { - (Column(col), Literal(val)) => (*op, col, val), - (Literal(val), Column(col)) => (op.commute()?, col, val), - _ => return None, // unsupported combination of operands - }; - let stats_col = match op { - LessThan | LessThanOrEqual => column_name!("minValues"), - GreaterThan | GreaterThanOrEqual => column_name!("maxValues"), - Equal => { - return as_data_skipping_predicate(&Expr::and( - Expr::le(Column(col.clone()), Literal(val.clone())), - Expr::le(Literal(val.clone()), Column(col.clone())), - )); - } - NotEqual => { - return Some(Expr::or( - Expr::gt(joined_column_expr!("minValues", col), val.clone()), - Expr::lt(joined_column_expr!("maxValues", col), val.clone()), - )); - } - _ => return None, // unsupported operation - }; - Some(Expr::binary(op, stats_col.join(col), val.clone())) - } - // push down Not by inverting everything below it - UnaryOperation { op: Not, expr } => as_inverted_data_skipping_predicate(expr), - UnaryOperation { op: IsNull, expr } => { - // to check if a column could have a null, we need two different checks, to see if - // the bounds are tight and then to actually do the check - if let Column(col) = expr.as_ref() { - let null_col = joined_column_expr!("nullCount", col); - Some(Expr::or( - get_tight_null_expr(null_col.clone()), - get_wide_null_expr(null_col), - )) - } else { - // can't check anything other than a col for null - None - } - } - VariadicOperation { op, exprs } => { - let exprs = exprs.iter().map(as_data_skipping_predicate); - match op { - VariadicOperator::And => Some(Expr::and_from(exprs.flatten())), - VariadicOperator::Or => Some(Expr::or_from(exprs.collect::>>()?)), - } - } - _ => None, - } +fn as_data_skipping_predicate(expr: &Expr, inverted: bool) -> Option { + 
DataSkippingPredicateCreator.eval_expr(expr, inverted) } pub(crate) struct DataSkippingFilter { @@ -223,7 +101,6 @@ impl DataSkippingFilter { .into_owned(); let stats_schema = Arc::new(StructType::new([ StructField::new("numRecords", DataType::LONG, true), - StructField::new("tightBounds", DataType::BOOLEAN, true), StructField::new("nullCount", nullcount_schema, true), StructField::new("minValues", minmax_schema.clone(), true), StructField::new("maxValues", minmax_schema, true), @@ -249,7 +126,7 @@ impl DataSkippingFilter { let skipping_evaluator = engine.get_expression_handler().get_evaluator( stats_schema.clone(), - Expr::struct_from([as_data_skipping_predicate(predicate)?]), + Expr::struct_from([as_data_skipping_predicate(predicate, false)?]), PREDICATE_SCHEMA.clone(), ); @@ -304,83 +181,100 @@ impl DataSkippingFilter { } } -#[cfg(test)] -mod tests { - use super::*; +struct DataSkippingPredicateCreator; + +impl DataSkippingPredicateEvaluator for DataSkippingPredicateCreator { + type Output = Expr; + type TypedStat = Expr; + type IntStat = Expr; + + /// Retrieves the minimum value of a column, if it exists and has the requested type. + fn get_min_stat(&self, col: &ColumnName, _data_type: &DataType) -> Option { + Some(joined_column_expr!("minValues", col)) + } - #[test] - fn test_rewrite_basic_comparison() { - let column = column_expr!("a"); - let lit_int = Expr::literal(1_i32); - let min_col = column_expr!("minValues.a"); - let max_col = column_expr!("maxValues.a"); + /// Retrieves the maximum value of a column, if it exists and has the requested type. 
+ fn get_max_stat(&self, col: &ColumnName, _data_type: &DataType) -> Option { + Some(joined_column_expr!("maxValues", col)) + } - let cases = [ - ( - column.clone().lt(lit_int.clone()), - Expr::lt(min_col.clone(), lit_int.clone()), - ), - ( - lit_int.clone().lt(column.clone()), - Expr::gt(max_col.clone(), lit_int.clone()), - ), - ( - column.clone().gt(lit_int.clone()), - Expr::gt(max_col.clone(), lit_int.clone()), - ), - ( - lit_int.clone().gt(column.clone()), - Expr::lt(min_col.clone(), lit_int.clone()), - ), - ( - column.clone().lt_eq(lit_int.clone()), - Expr::le(min_col.clone(), lit_int.clone()), - ), - ( - lit_int.clone().lt_eq(column.clone()), - Expr::ge(max_col.clone(), lit_int.clone()), - ), - ( - column.clone().gt_eq(lit_int.clone()), - Expr::ge(max_col.clone(), lit_int.clone()), - ), - ( - lit_int.clone().gt_eq(column.clone()), - Expr::le(min_col.clone(), lit_int.clone()), - ), - ( - column.clone().eq(lit_int.clone()), - Expr::and_from([ - Expr::le(min_col.clone(), lit_int.clone()), - Expr::ge(max_col.clone(), lit_int.clone()), - ]), - ), - ( - lit_int.clone().eq(column.clone()), - Expr::and_from([ - Expr::le(min_col.clone(), lit_int.clone()), - Expr::ge(max_col.clone(), lit_int.clone()), - ]), - ), - ( - column.clone().ne(lit_int.clone()), - Expr::or_from([ - Expr::gt(min_col.clone(), lit_int.clone()), - Expr::lt(max_col.clone(), lit_int.clone()), - ]), - ), - ( - lit_int.clone().ne(column.clone()), - Expr::or_from([ - Expr::gt(min_col.clone(), lit_int.clone()), - Expr::lt(max_col.clone(), lit_int.clone()), - ]), - ), - ]; + /// Retrieves the null count of a column, if it exists. + fn get_nullcount_stat(&self, col: &ColumnName) -> Option { + Some(joined_column_expr!("nullCount", col)) + } + + /// Retrieves the row count of a column (parquet footers always include this stat). 
+ fn get_rowcount_stat(&self) -> Option<Expr> { + Some(column_expr!("numRecords")) + } - for (input, expected) in cases { - let rewritten = as_data_skipping_predicate(&input).unwrap(); - assert_eq!(rewritten, expected) + fn eval_partial_cmp( + &self, + ord: Ordering, + col: Expr, + val: &Scalar, + inverted: bool, + ) -> Option<Expr> { + let op = match (ord, inverted) { + (Ordering::Less, false) => BinaryOperator::LessThan, + (Ordering::Less, true) => BinaryOperator::GreaterThanOrEqual, + (Ordering::Equal, false) => BinaryOperator::Equal, + (Ordering::Equal, true) => BinaryOperator::NotEqual, + (Ordering::Greater, false) => BinaryOperator::GreaterThan, + (Ordering::Greater, true) => BinaryOperator::LessThanOrEqual, + }; + Some(Expr::binary(op, col, val.clone())) + } + + fn eval_scalar(&self, val: &Scalar, inverted: bool) -> Option<Expr> { + PredicateEvaluatorDefaults::eval_scalar(val, inverted).map(Expr::literal) + } + + fn eval_is_null(&self, col: &ColumnName, inverted: bool) -> Option<Expr> { + let safe_to_skip = match inverted { + true => self.get_rowcount_stat()?, // all-null + false => Expr::literal(0i64), // no-null + }; + Some(Expr::ne(self.get_nullcount_stat(col)?, safe_to_skip)) + } + + fn eval_binary_scalars( + &self, + op: BinaryOperator, + left: &Scalar, + right: &Scalar, + inverted: bool, + ) -> Option<Expr> { + PredicateEvaluatorDefaults::eval_binary_scalars(op, left, right, inverted) + .map(Expr::literal) + } + + fn finish_eval_variadic( + &self, + mut op: VariadicOperator, + exprs: impl IntoIterator<Item = Option<Expr>>, + inverted: bool, + ) -> Option<Expr> { + if inverted { + op = op.invert(); } + // NOTE: We can potentially see a LOT of NULL inputs in a big WHERE clause with lots of + // unsupported data skipping operations. We can't "just" flatten them all away for AND, + // because that could produce TRUE where NULL would otherwise be expected. Similarly, we + // don't want to "just" try_collect inputs for OR, because that can cause OR to produce NULL + // where FALSE would otherwise be expected. 
So, we filter out all nulls except the first, + // observing that one NULL is enough to produce the correct behavior during predicate eval. + let mut keep_null = true; + let exprs: Vec<_> = exprs + .into_iter() + .flat_map(|e| match e { + Some(expr) => Some(expr), + None => keep_null.then(|| { + keep_null = false; + Expr::null_literal(DataType::BOOLEAN) + }), + }) + .collect(); + Some(Expr::variadic(op, exprs)) } } diff --git a/kernel/src/scan/data_skipping/tests.rs b/kernel/src/scan/data_skipping/tests.rs new file mode 100644 index 000000000..e12adb526 --- /dev/null +++ b/kernel/src/scan/data_skipping/tests.rs @@ -0,0 +1,254 @@ +use super::*; + +use crate::expressions::column_name; +use crate::predicates::{DefaultPredicateEvaluator, UnimplementedColumnResolver}; +use std::collections::HashMap; + +const TRUE: Option = Some(true); +const FALSE: Option = Some(false); +const NULL: Option = None; + +macro_rules! expect_eq { + ( $expr: expr, $expect: expr, $fmt: literal ) => { + let expect = ($expect); + let result = ($expr); + assert!( + result == expect, + "Expected {} = {:?}, got {:?}", + format!($fmt), + expect, + result + ); + }; +} + +#[test] +fn test_eval_is_null() { + let col = &column_expr!("x"); + let expressions = [Expr::is_null(col.clone()), !Expr::is_null(col.clone())]; + + let do_test = |nullcount: i64, expected: &[Option]| { + let resolver = HashMap::from_iter([ + (column_name!("numRecords"), Scalar::from(2i64)), + (column_name!("nullCount.x"), Scalar::from(nullcount)), + ]); + let filter = DefaultPredicateEvaluator::from(resolver); + for (expr, expect) in expressions.iter().zip(expected) { + let pred = as_data_skipping_predicate(expr, false).unwrap(); + expect_eq!( + filter.eval_expr(&pred, false), + *expect, + "{expr:#?} became {pred:#?} ({nullcount} nulls)" + ); + } + }; + + // no nulls + do_test(0, &[FALSE, TRUE]); + + // some nulls + do_test(1, &[TRUE, TRUE]); + + // all nulls + do_test(2, &[TRUE, FALSE]); +} + +#[test] +fn 
test_eval_binary_comparisons() { + let col = &column_expr!("x"); + let five = &Scalar::from(5); + let ten = &Scalar::from(10); + let fifteen = &Scalar::from(15); + let null = &Scalar::Null(DataType::INTEGER); + + let expressions = [ + Expr::lt(col.clone(), ten.clone()), + Expr::le(col.clone(), ten.clone()), + Expr::eq(col.clone(), ten.clone()), + Expr::ne(col.clone(), ten.clone()), + Expr::gt(col.clone(), ten.clone()), + Expr::ge(col.clone(), ten.clone()), + ]; + + let do_test = |min: &Scalar, max: &Scalar, expected: &[Option<bool>]| { + let resolver = HashMap::from_iter([ + (column_name!("minValues.x"), min.clone()), + (column_name!("maxValues.x"), max.clone()), + ]); + let filter = DefaultPredicateEvaluator::from(resolver); + for (expr, expect) in expressions.iter().zip(expected.iter()) { + let pred = as_data_skipping_predicate(expr, false).unwrap(); + expect_eq!( + filter.eval_expr(&pred, false), + *expect, + "{expr:#?} became {pred:#?} with [{min}..{max}]" + ); + } + }; + + // value < min = max (15..15 = 10, 15..15 <= 10, etc) + do_test(fifteen, fifteen, &[FALSE, FALSE, FALSE, TRUE, TRUE, TRUE]); + + // min = max = value (10..10 = 10, 10..10 <= 10, etc) + // + // NOTE: missing min or max stat produces NULL output if the expression needed it. 
+ do_test(ten, ten, &[FALSE, TRUE, TRUE, FALSE, FALSE, TRUE]); + do_test(null, ten, &[NULL, NULL, NULL, NULL, FALSE, TRUE]); + do_test(ten, null, &[FALSE, TRUE, NULL, NULL, NULL, NULL]); + + // min = max < value (5..5 = 10, 5..5 <= 10, etc) + do_test(five, five, &[TRUE, TRUE, FALSE, TRUE, FALSE, FALSE]); + + // value = min < max (10..15 = 10, 10..15 <= 10, etc) + do_test(ten, fifteen, &[FALSE, TRUE, TRUE, TRUE, TRUE, TRUE]); + + // min < value < max (5..15 = 10, 5..15 <= 10, etc) + do_test(five, fifteen, &[TRUE, TRUE, TRUE, TRUE, TRUE, TRUE]); +} + +#[test] +fn test_eval_variadic() { + let test_cases = &[ + (&[] as &[Option<bool>], TRUE, FALSE), + (&[TRUE], TRUE, TRUE), + (&[FALSE], FALSE, FALSE), + (&[NULL], NULL, NULL), + (&[TRUE, TRUE], TRUE, TRUE), + (&[TRUE, FALSE], FALSE, TRUE), + (&[TRUE, NULL], NULL, TRUE), + (&[FALSE, TRUE], FALSE, TRUE), + (&[FALSE, FALSE], FALSE, FALSE), + (&[FALSE, NULL], FALSE, NULL), + (&[NULL, TRUE], NULL, TRUE), + (&[NULL, FALSE], FALSE, NULL), + (&[NULL, NULL], NULL, NULL), + // Every combo of 1:2 + (&[TRUE, FALSE, FALSE], FALSE, TRUE), + (&[FALSE, TRUE, FALSE], FALSE, TRUE), + (&[FALSE, FALSE, TRUE], FALSE, TRUE), + (&[TRUE, NULL, NULL], NULL, TRUE), + (&[NULL, TRUE, NULL], NULL, TRUE), + (&[NULL, NULL, TRUE], NULL, TRUE), + (&[FALSE, TRUE, TRUE], FALSE, TRUE), + (&[TRUE, FALSE, TRUE], FALSE, TRUE), + (&[TRUE, TRUE, FALSE], FALSE, TRUE), + (&[FALSE, NULL, NULL], FALSE, NULL), + (&[NULL, FALSE, NULL], FALSE, NULL), + (&[NULL, NULL, FALSE], FALSE, NULL), + (&[NULL, TRUE, TRUE], NULL, TRUE), + (&[TRUE, NULL, TRUE], NULL, TRUE), + (&[TRUE, TRUE, NULL], NULL, TRUE), + (&[NULL, FALSE, FALSE], FALSE, NULL), + (&[FALSE, NULL, FALSE], FALSE, NULL), + (&[FALSE, FALSE, NULL], FALSE, NULL), + // Every unique ordering of 3 + (&[TRUE, FALSE, NULL], FALSE, TRUE), + (&[TRUE, NULL, FALSE], FALSE, TRUE), + (&[FALSE, TRUE, NULL], FALSE, TRUE), + (&[FALSE, NULL, TRUE], FALSE, TRUE), + (&[NULL, TRUE, FALSE], FALSE, TRUE), + (&[NULL, FALSE, TRUE], FALSE, 
TRUE), + ]; + let filter = DefaultPredicateEvaluator::from(UnimplementedColumnResolver); + for (inputs, expect_and, expect_or) in test_cases { + let inputs: Vec<_> = inputs + .iter() + .map(|val| match val { + Some(v) => Expr::literal(v), + None => Expr::null_literal(DataType::BOOLEAN), + }) + .collect(); + + let expr = Expr::and_from(inputs.clone()); + let pred = as_data_skipping_predicate(&expr, false).unwrap(); + expect_eq!( + filter.eval_expr(&pred, false), + *expect_and, + "AND({inputs:?})" + ); + + let expr = Expr::or_from(inputs.clone()); + let pred = as_data_skipping_predicate(&expr, false).unwrap(); + expect_eq!(filter.eval_expr(&pred, false), *expect_or, "OR({inputs:?})"); + + let expr = Expr::and_from(inputs.clone()); + let pred = as_data_skipping_predicate(&expr, true).unwrap(); + expect_eq!( + filter.eval_expr(&pred, false), + expect_and.map(|val| !val), + "NOT AND({inputs:?})" + ); + + let expr = Expr::or_from(inputs.clone()); + let pred = as_data_skipping_predicate(&expr, true).unwrap(); + expect_eq!( + filter.eval_expr(&pred, false), + expect_or.map(|val| !val), + "NOT OR({inputs:?})" + ); + } +} + +// DISTINCT is actually quite complex internally. It indirectly exercises IS [NOT] NULL and +// AND/OR. A different test validates min/max comparisons, so here we're mostly worried about NULL +// vs. non-NULL literals and nullcount/rowcount stats. 
+#[test] +fn test_eval_distinct() { + let col = &column_expr!("x"); + let five = &Scalar::from(5); + let ten = &Scalar::from(10); + let fifteen = &Scalar::from(15); + let null = &Scalar::Null(DataType::INTEGER); + + let expressions = [ + Expr::distinct(col.clone(), ten.clone()), + !Expr::distinct(col.clone(), ten.clone()), + Expr::distinct(col.clone(), null.clone()), + !Expr::distinct(col.clone(), null.clone()), + ]; + + let do_test = |min: &Scalar, max: &Scalar, nullcount: i64, expected: &[Option]| { + let resolver = HashMap::from_iter([ + (column_name!("numRecords"), Scalar::from(2i64)), + (column_name!("nullCount.x"), Scalar::from(nullcount)), + (column_name!("minValues.x"), min.clone()), + (column_name!("maxValues.x"), max.clone()), + ]); + let filter = DefaultPredicateEvaluator::from(resolver); + for (expr, expect) in expressions.iter().zip(expected) { + let pred = as_data_skipping_predicate(expr, false).unwrap(); + expect_eq!( + filter.eval_expr(&pred, false), + *expect, + "{expr:#?} became {pred:#?} ({min}..{max}, {nullcount} nulls)" + ); + } + }; + + // min = max = value, no nulls + do_test(ten, ten, 0, &[FALSE, TRUE, TRUE, FALSE]); + + // min = max = value, some nulls + do_test(ten, ten, 1, &[TRUE, TRUE, TRUE, TRUE]); + + // min = max = value, all nulls + do_test(ten, ten, 2, &[TRUE, FALSE, FALSE, TRUE]); + + // value < min = max, no nulls + do_test(fifteen, fifteen, 0, &[TRUE, FALSE, TRUE, FALSE]); + + // value < min = max, some nulls + do_test(fifteen, fifteen, 1, &[TRUE, FALSE, TRUE, TRUE]); + + // value < min = max, all nulls + do_test(fifteen, fifteen, 2, &[TRUE, FALSE, FALSE, TRUE]); + + // min < value < max, no nulls + do_test(five, fifteen, 0, &[TRUE, TRUE, TRUE, FALSE]); + + // min < value < max, some nulls + do_test(five, fifteen, 1, &[TRUE, TRUE, TRUE, TRUE]); + + // min < value < max, all nulls + do_test(five, fifteen, 2, &[TRUE, FALSE, FALSE, TRUE]); +} diff --git a/kernel/tests/read.rs b/kernel/tests/read.rs index a2af1f53d..02024a79c 100644 
--- a/kernel/tests/read.rs +++ b/kernel/tests/read.rs @@ -13,7 +13,7 @@ use delta_kernel::engine::default::DefaultEngine; use delta_kernel::expressions::{column_expr, BinaryOperator, Expression}; use delta_kernel::scan::state::{visit_scan_files, DvInfo, Stats}; use delta_kernel::scan::{transform_to_logical, Scan}; -use delta_kernel::schema::Schema; +use delta_kernel::schema::{DataType, Schema}; use delta_kernel::{Engine, FileMeta, Table}; use object_store::{memory::InMemory, path::Path, ObjectStore}; use test_utils::{ @@ -289,11 +289,11 @@ async fn stats() -> Result<(), Box> { (GreaterThanOrEqual, 7, vec![&batch2]), (GreaterThanOrEqual, 8, vec![]), (NotEqual, 0, vec![&batch2, &batch1]), - (NotEqual, 1, vec![&batch2]), - (NotEqual, 3, vec![&batch2]), + (NotEqual, 1, vec![&batch2, &batch1]), + (NotEqual, 3, vec![&batch2, &batch1]), (NotEqual, 4, vec![&batch2, &batch1]), - (NotEqual, 5, vec![&batch1]), - (NotEqual, 7, vec![&batch1]), + (NotEqual, 5, vec![&batch2, &batch1]), + (NotEqual, 7, vec![&batch2, &batch1]), (NotEqual, 8, vec![&batch2, &batch1]), ]; for (op, value, expected_batches) in test_cases { @@ -301,7 +301,7 @@ async fn stats() -> Result<(), Box> { let scan = snapshot .clone() .scan_builder() - .with_predicate(Arc::new(predicate)) + .with_predicate(Arc::new(predicate.clone())) .build()?; let expected_files = expected_batches.len(); @@ -313,7 +313,7 @@ async fn stats() -> Result<(), Box> { files_scanned += 1; assert_eq!(into_record_batch(raw_data), expected.clone()); } - assert_eq!(expected_files, files_scanned); + assert_eq!(expected_files, files_scanned, "{predicate:?}"); } Ok(()) } @@ -850,28 +850,39 @@ fn not_and_or_predicates() -> Result<(), Box> { fn invalid_skips_none_predicates() -> Result<(), Box> { let empty_struct = Expression::struct_from(vec![]); let cases = vec![ + (Expression::literal(false), table_for_numbers(vec![])), + ( + Expression::literal(true), + table_for_numbers(vec![1, 2, 3, 4, 5, 6]), + ), ( Expression::literal(3i64), 
table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( column_expr!("number").distinct(3i64), - table_for_numbers(vec![1, 2, 3, 4, 5, 6]), + table_for_numbers(vec![1, 2, 4, 5, 6]), ), ( - column_expr!("number").gt(empty_struct.clone()), + column_expr!("number").distinct(Expression::null_literal(DataType::LONG)), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - column_expr!("number").and(empty_struct.clone().is_null()), - table_for_numbers(vec![1, 2, 3, 4, 5, 6]), + Expression::not(column_expr!("number").distinct(3i64)), + table_for_numbers(vec![3]), ), ( - Expression::not(column_expr!("number").gt(empty_struct.clone())), + Expression::not( + column_expr!("number").distinct(Expression::null_literal(DataType::LONG)), + ), + table_for_numbers(vec![]), + ), + ( + column_expr!("number").gt(empty_struct.clone()), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::not(column_expr!("number").and(empty_struct.clone().is_null())), + Expression::not(column_expr!("number").gt(empty_struct.clone())), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ];