From e48d2386ecce55739a3dad005ec70bb218d3b0b2 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Wed, 23 Oct 2024 14:14:20 -0600 Subject: [PATCH] Expression::Column references a ColumnName object instead of a String (#400) Kernel only partly supports nested column names today. One of the biggest barriers to closing the gap is that Expression tracks column names as simple strings, so there is no reliable way to deal with nesting. This PR takes a first step by defining a proper `ColumnName` struct for representing column names. So far, it behaves the same as before -- internally it's just a simple string, occasionally interpreted as nested by splitting at periods. Changing the code to use the new construct is noisy, so we make those changes here as a standalone prefactoring step. A follow-up PR will actually add nesting support. Resolves https://github.com/delta-incubator/delta-kernel-rs/issues/422 --- derive-macros/src/lib.rs | 20 ++ ffi/src/expressions.rs | 6 +- kernel/src/actions/set_transaction.rs | 5 +- kernel/src/engine/arrow_expression.rs | 31 +- .../parquet_row_group_skipping/tests.rs | 31 +- .../engine/parquet_stats_skipping/tests.rs | 12 +- kernel/src/expressions/column_names.rs | 269 ++++++++++++++++++ kernel/src/expressions/mod.rs | 37 +-- kernel/src/expressions/scalars.rs | 4 +- kernel/src/scan/data_skipping.rs | 54 ++-- kernel/src/scan/log_replay.rs | 14 +- kernel/src/scan/mod.rs | 12 +- kernel/src/snapshot.rs | 5 +- kernel/tests/read.rs | 85 +++--- 14 files changed, 441 insertions(+), 144 deletions(-) create mode 100644 kernel/src/expressions/column_names.rs diff --git a/derive-macros/src/lib.rs b/derive-macros/src/lib.rs index 9bf74690b..3b2e35aa4 100644 --- a/derive-macros/src/lib.rs +++ b/derive-macros/src/lib.rs @@ -5,6 +5,26 @@ use syn::{ parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, Meta, PathArguments, Type, }; +/// Parses a dot-delimited column name into an array of field names. See +/// [`delta_kernel::expressions::column_name::column_name`] macro for details. +#[proc_macro] +pub fn parse_column_name(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let is_valid = |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.'; + let err = match syn::parse(input) { + Ok(syn::Lit::Str(name)) => match name.value().chars().find(|c| !is_valid(*c)) { + Some(bad_char) => Error::new(name.span(), format!("Invalid character: {bad_char:?}")), + _ => { + let path = name.value(); + let path = path.split('.').map(proc_macro2::Literal::string); + return quote_spanned! { name.span() => [#(#path),*] }.into(); + } + }, + Ok(lit) => Error::new(lit.span(), "Expected a string literal"), + Err(err) => err, + }; + err.into_compile_error().into() +} + /// Derive a `delta_kernel::schemas::ToDataType` implementation for the annotated struct.
The actual /// field names in the schema (and therefore of the struct members) are all mandated by the Delta /// spec, and so the user of this macro is responsible for ensuring that diff --git a/ffi/src/expressions.rs b/ffi/src/expressions.rs index 087cae163..19c334038 100644 --- a/ffi/src/expressions.rs +++ b/ffi/src/expressions.rs @@ -5,7 +5,7 @@ use crate::{ ReferenceSet, TryFromStringSlice, }; use delta_kernel::{ - expressions::{BinaryOperator, Expression, UnaryOperator}, + expressions::{BinaryOperator, ColumnName, Expression, UnaryOperator}, DeltaResult, }; @@ -146,7 +146,9 @@ fn visit_expression_column_impl( state: &mut KernelExpressionVisitorState, name: DeltaResult, ) -> DeltaResult { - Ok(wrap_expression(state, Expression::Column(name?))) + // TODO: FIXME: This is incorrect if any field name in the column path contains a period. + let name = ColumnName::new(name?.split('.')).into(); + Ok(wrap_expression(state, name)) } #[no_mangle] diff --git a/kernel/src/actions/set_transaction.rs b/kernel/src/actions/set_transaction.rs index 5dcba4323..5cfae8863 100644 --- a/kernel/src/actions/set_transaction.rs +++ b/kernel/src/actions/set_transaction.rs @@ -2,8 +2,9 @@ use std::sync::{Arc, LazyLock}; use crate::actions::visitors::SetTransactionVisitor; use crate::actions::{get_log_schema, SetTransaction, SET_TRANSACTION_NAME}; +use crate::expressions::column_expr; use crate::snapshot::Snapshot; -use crate::{DeltaResult, Engine, EngineData, Expression, ExpressionRef, SchemaRef}; +use crate::{DeltaResult, Engine, EngineData, ExpressionRef, SchemaRef}; pub use crate::actions::visitors::SetTransactionMap; pub struct SetTransactionScanner { @@ -53,7 +54,7 @@ impl SetTransactionScanner { // point filtering by a particular app id, even if we have one, because app ids are all in // the a single checkpoint part having large min/max range (because they're usually uuids). 
static META_PREDICATE: LazyLock> = - LazyLock::new(|| Some(Arc::new(Expression::column("txn.appId").is_not_null()))); + LazyLock::new(|| Some(Arc::new(column_expr!("txn.appId").is_not_null()))); self.snapshot .log_segment .replay(engine, schema.clone(), schema, META_PREDICATE.clone()) diff --git a/kernel/src/engine/arrow_expression.rs b/kernel/src/engine/arrow_expression.rs index fa87cb8bb..cb9138299 100644 --- a/kernel/src/engine/arrow_expression.rs +++ b/kernel/src/engine/arrow_expression.rs @@ -589,9 +589,9 @@ mod tests { let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap(); - let not_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item")); + let not_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item")); - let in_op = Expression::binary(BinaryOperator::In, 5, Expression::column("item")); + let in_op = Expression::binary(BinaryOperator::In, 5, column_expr!("item")); let result = evaluate_expression(¬_op, &batch, None).unwrap(); let expected = BooleanArray::from(vec![true, false, true]); @@ -609,7 +609,7 @@ mod tests { let schema = Schema::new([field.clone()]); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); - let in_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item")); + let in_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item")); let in_result = evaluate_expression(&in_op, &batch, None); @@ -654,8 +654,8 @@ mod tests { let in_op = Expression::binary( BinaryOperator::NotIn, - Expression::column("item"), - Expression::column("item"), + column_expr!("item"), + column_expr!("item"), ); let in_result = evaluate_expression(&in_op, &batch, None); @@ -679,10 +679,9 @@ mod tests { let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap(); - let str_not_op = - Expression::binary(BinaryOperator::NotIn, "bye", Expression::column("item")); + let str_not_op = Expression::binary(BinaryOperator::NotIn, "bye", column_expr!("item")); - let str_in_op = Expression::binary(BinaryOperator::In, "hi", Expression::column("item")); + let str_in_op = Expression::binary(BinaryOperator::In, "hi", column_expr!("item")); let result = evaluate_expression(&str_in_op, &batch, None).unwrap(); let expected = BooleanArray::from(vec![true, true, true]); @@ -699,7 +698,7 @@ mod tests { let values = Int32Array::from(vec![1, 2, 3]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values.clone())]).unwrap(); - let column = Expression::column("a"); + let column = column_expr!("a"); let results = evaluate_expression(&column, &batch, None).unwrap(); assert_eq!(results.as_ref(), &values); @@ -720,7 +719,7 @@ mod tests { vec![Arc::new(struct_array.clone())], ) .unwrap(); - let column = Expression::column("b.a"); + let column = column_expr!("b.a"); let results = evaluate_expression(&column, &batch, None).unwrap(); assert_eq!(results.as_ref(), &values); } @@ -730,7 +729,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let values = Int32Array::from(vec![1, 2, 3]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap(); - let column = Expression::column("a"); + let column = column_expr!("a"); let expression = column.clone().add(1); let results = evaluate_expression(&expression, &batch, 
None).unwrap(); @@ -766,8 +765,8 @@ mod tests { vec![Arc::new(values.clone()), Arc::new(values)], ) .unwrap(); - let column_a = Expression::column("a"); - let column_b = Expression::column("b"); + let column_a = column_expr!("a"); + let column_b = column_expr!("b"); let expression = column_a.clone().add(column_b.clone()); let results = evaluate_expression(&expression, &batch, None).unwrap(); @@ -790,7 +789,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let values = Int32Array::from(vec![1, 2, 3]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap(); - let column = Expression::column("a"); + let column = column_expr!("a"); let expression = column.clone().lt(2); let results = evaluate_expression(&expression, &batch, None).unwrap(); @@ -837,8 +836,8 @@ mod tests { ], ) .unwrap(); - let column_a = Expression::column("a"); - let column_b = Expression::column("b"); + let column_a = column_expr!("a"); + let column_b = column_expr!("b"); let expression = column_a.clone().and(column_b.clone()); let results = diff --git a/kernel/src/engine/parquet_row_group_skipping/tests.rs b/kernel/src/engine/parquet_row_group_skipping/tests.rs index 6f5dd3a48..19bd2b5bf 100644 --- a/kernel/src/engine/parquet_row_group_skipping/tests.rs +++ b/kernel/src/engine/parquet_row_group_skipping/tests.rs @@ -1,4 +1,5 @@ use super::*; +use crate::expressions::column_expr; use crate::Expression; use parquet::arrow::arrow_reader::ArrowReaderMetadata; use std::fs::File; @@ -39,21 +40,21 @@ fn test_get_stat_values() { // The expression doesn't matter -- it just needs to mention all the columns we care about. let columns = Expression::and_from(vec![ - Expression::column("varlen.utf8"), - Expression::column("numeric.ints.int64"), - Expression::column("numeric.ints.int32"), - Expression::column("numeric.ints.int16"), - Expression::column("numeric.ints.int8"), - Expression::column("numeric.floats.float32"), - Expression::column("numeric.floats.float64"), - Expression::column("bool"), - Expression::column("varlen.binary"), - Expression::column("numeric.decimals.decimal32"), - Expression::column("numeric.decimals.decimal64"), - Expression::column("numeric.decimals.decimal128"), - Expression::column("chrono.date32"), - Expression::column("chrono.timestamp"), - Expression::column("chrono.timestamp_ntz"), + column_expr!("varlen.utf8"), + column_expr!("numeric.ints.int64"), + column_expr!("numeric.ints.int32"), + column_expr!("numeric.ints.int16"), + column_expr!("numeric.ints.int8"), + column_expr!("numeric.floats.float32"), + column_expr!("numeric.floats.float64"), + column_expr!("bool"), + column_expr!("varlen.binary"), + column_expr!("numeric.decimals.decimal32"), + column_expr!("numeric.decimals.decimal64"), + column_expr!("numeric.decimals.decimal128"), + column_expr!("chrono.date32"), + column_expr!("chrono.timestamp"), + column_expr!("chrono.timestamp_ntz"), ]); let filter = RowGroupFilter::new(metadata.metadata().row_group(0), &columns); diff --git a/kernel/src/engine/parquet_stats_skipping/tests.rs b/kernel/src/engine/parquet_stats_skipping/tests.rs index a95ac4102..b4bdd97d3 100644 --- a/kernel/src/engine/parquet_stats_skipping/tests.rs +++ b/kernel/src/engine/parquet_stats_skipping/tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::expressions::{ArrayData, StructData}; +use crate::expressions::{column_expr, ArrayData, StructData}; use crate::schema::ArrayType; use crate::DataType; @@ -337,7 +337,7 @@ fn test_binary_eq_ne() { const LO: Scalar = 
Scalar::Long(1); const MID: Scalar = Scalar::Long(10); const HI: Scalar = Scalar::Long(100); - let col = &Expression::column("x"); + let col = &column_expr!("x"); for inverted in [false, true] { // negative test -- mismatched column type @@ -485,7 +485,7 @@ fn test_binary_lt_ge() { const LO: Scalar = Scalar::Long(1); const MID: Scalar = Scalar::Long(10); const HI: Scalar = Scalar::Long(100); - let col = &Expression::column("x"); + let col = &column_expr!("x"); for inverted in [false, true] { expect_eq!( @@ -585,7 +585,7 @@ fn test_binary_le_gt() { const LO: Scalar = Scalar::Long(1); const MID: Scalar = Scalar::Long(10); const HI: Scalar = Scalar::Long(100); - let col = &Expression::column("x"); + let col = &column_expr!("x"); for inverted in [false, true] { // negative test -- mismatched column type @@ -736,7 +736,7 @@ impl ParquetStatsSkippingFilter for NullCountTestFilter { fn test_not_null() { use UnaryOperator::IsNull; - let col = &Expression::column("x"); + let col = &column_expr!("x"); for inverted in [false, true] { expect_eq!( NullCountTestFilter::new(None, 10).apply_unary(IsNull, col, inverted), @@ -809,7 +809,7 @@ impl ParquetStatsSkippingFilter for AllNullTestFilter { #[test] fn test_sql_where() { - let col = &Expression::column("x"); + let col = &column_expr!("x"); let val = &Expression::literal(1); const NULL: Expression = Expression::Literal(Scalar::Null(DataType::BOOLEAN)); const FALSE: Expression = Expression::Literal(Scalar::Boolean(false)); diff --git a/kernel/src/expressions/column_names.rs b/kernel/src/expressions/column_names.rs new file mode 100644 index 000000000..4129abb30 --- /dev/null +++ b/kernel/src/expressions/column_names.rs @@ -0,0 +1,269 @@ +use std::borrow::Borrow; +use std::fmt::{Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +/// A (possibly nested) column name. +// TODO: Track name as a path rather than a single string +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)] +pub struct ColumnName { + path: String, +} + +impl ColumnName { + /// Constructs a new column name from an iterator of field names. The field names are joined + /// together to make a single path. + pub fn new(path: impl IntoIterator<Item = impl Into<String>>) -> Self { + let path: Vec<_> = path.into_iter().map(Into::into).collect(); + let path = path.join("."); + Self { path } + } + + /// Joins this column with another, concatenating their fields into a single nested column path. + /// + /// NOTE: This is a convenience method that copies both arguments without consuming them. If more + /// arguments are needed, or if performance is a concern, it is recommended to use + /// [`FromIterator<ColumnName>`](#impl-FromIterator-for-ColumnName) instead: + /// + /// ``` + /// # use delta_kernel::expressions::ColumnName; + /// let x = ColumnName::new(["a", "b"]); + /// let y = ColumnName::new(["c", "d"]); + /// let joined: ColumnName = [x, y].into_iter().collect(); + /// assert_eq!(joined, ColumnName::new(["a", "b", "c", "d"])); + /// ``` + pub fn join(&self, right: &ColumnName) -> ColumnName { + [self.clone(), right.clone()].into_iter().collect() + } + + /// The path of field names for this column name + pub fn path(&self) -> &String { + &self.path + } + + /// Consumes this column name and returns the path of field names. + pub fn into_inner(self) -> String { + self.path + } +} + +/// Creates a new column name from a path of field names. Each field name is taken as-is, and may +/// contain arbitrary characters (including periods, spaces, etc.).
+impl<A: Into<String>> FromIterator<A> for ColumnName { + fn from_iter<T>(iter: T) -> Self + where + T: IntoIterator<Item = A>, + { + Self::new(iter) + } +} + +/// Creates a new column name by joining multiple column names together. +impl FromIterator<ColumnName> for ColumnName { + fn from_iter<T>(iter: T) -> Self + where + T: IntoIterator<Item = ColumnName>, + { + Self::new(iter.into_iter().map(ColumnName::into_inner)) + } +} + +impl Display for ColumnName { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + (**self).fmt(f) + } +} + +impl Deref for ColumnName { + type Target = String; + + fn deref(&self) -> &String { + &self.path + } +} + +// Allows searching collections of `ColumnName` without an owned key value +impl Borrow<String> for ColumnName { + fn borrow(&self) -> &String { + self + } +} + +// Allows searching collections of `&ColumnName` without an owned key value. Needed because there is +// apparently no blanket `impl Borrow<U> for &T where T: Borrow<U>`, even though `Eq` [1] and +// `Hash` [2] both have blanket impls for treating `&T` like `T`. +// +// [1] https://doc.rust-lang.org/std/cmp/trait.Eq.html#impl-Eq-for-%26A +// [2] https://doc.rust-lang.org/std/hash/trait.Hash.html#impl-Hash-for-%26T +impl Borrow<String> for &ColumnName { + fn borrow(&self) -> &String { + self + } +} + +impl Hash for ColumnName { + fn hash<H: Hasher>(&self, hasher: &mut H) { + (**self).hash(hasher) + } +} + +/// Creates a nested column name whose field names are all simple column names (containing only +/// alphanumeric characters and underscores), delimited by dots. This macro is provided as a +/// convenience for the common case where the caller knows the column name contains only simple +/// field names and that splitting by periods is safe: +/// +/// ``` +/// # use delta_kernel::expressions::{column_name, ColumnName}; +/// assert_eq!(column_name!("a.b.c"), ColumnName::new(["a", "b", "c"])); +/// ``` +/// +/// To avoid accidental misuse, the argument must be a string literal, so the compiler can validate +/// the safety conditions. Thus, the following uses would fail to compile: +/// +/// ```compile_fail +/// # use delta_kernel::expressions::column_name; +/// let s = "a.b"; +/// let name = column_name!(s); // not a string literal +/// ``` +/// +/// ```compile_fail +/// # use delta_kernel::expressions::column_name; +/// let name = column_name!("a b"); // non-alphanumeric character +/// ``` +// NOTE: Macros are only public if exported, which defines them at the root of the crate. But we +// don't want it there. So, we export a hidden macro and pub use it here where we actually want it. +#[macro_export] +#[doc(hidden)] +macro_rules! __column_name { + ( $($name:tt)* ) => { + $crate::expressions::ColumnName::new(delta_kernel_derive::parse_column_name!($($name)*)) + }; +} +#[doc(inline)] +pub use __column_name as column_name; + +/// Joins two column names together, when one or both inputs might be literal strings representing +/// simple (non-nested) column names. For example: +/// +/// ``` +/// # use delta_kernel::expressions::{column_name, joined_column_name}; +/// assert_eq!(joined_column_name!("a.b", "c"), column_name!("a.b").join(&column_name!("c"))) +/// ``` +/// +/// To avoid accidental misuse, at least one argument must be a string literal. Thus, the following +/// invocation would fail to compile: +/// +/// ```compile_fail +/// # use delta_kernel::expressions::joined_column_name; +/// let s = "s"; +/// let name = joined_column_name!(s, s); +/// ``` +#[macro_export] +#[doc(hidden)] +macro_rules!
__joined_column_name { + ( $left:literal, $right:literal ) => { + $crate::__column_name!($left).join(&$crate::__column_name!($right)) + }; + ( $left:literal, $right:expr ) => { + $crate::__column_name!($left).join(&$right) + }; + ( $left:expr, $right:literal) => { + $left.join(&$crate::__column_name!($right)) + }; + ( $($other:tt)* ) => { + compile_error!("joined_column_name!() requires at least one string literal input") + }; +} +#[doc(inline)] +pub use __joined_column_name as joined_column_name; + +#[macro_export] +#[doc(hidden)] +macro_rules! __column_expr { + ( $($name:tt)* ) => { + $crate::expressions::Expression::from($crate::__column_name!($($name)*)) + }; +} +#[doc(inline)] +pub use __column_expr as column_expr; + +#[macro_export] +#[doc(hidden)] +macro_rules! __joined_column_expr { + ( $($name:tt)* ) => { + $crate::expressions::Expression::from($crate::__joined_column_name!($($name)*)) + }; +} +#[doc(inline)] +pub use __joined_column_expr as joined_column_expr; + +#[cfg(test)] +mod test { + use super::*; + use delta_kernel_derive::parse_column_name; + + #[test] + fn test_parse_column_name_macros() { + assert_eq!(parse_column_name!("a"), ["a"]); + + assert_eq!(parse_column_name!("a"), ["a"]); + assert_eq!(parse_column_name!("a.b"), ["a", "b"]); + assert_eq!(parse_column_name!("a.b.c"), ["a", "b", "c"]); + } + + #[test] + fn test_column_name_macros() { + let simple = column_name!("x"); + let nested = column_name!("x.y"); + + assert_eq!(column_name!("a"), ColumnName::new(["a"])); + assert_eq!(column_name!("a.b"), ColumnName::new(["a", "b"])); + assert_eq!(column_name!("a.b.c"), ColumnName::new(["a", "b", "c"])); + + assert_eq!(joined_column_name!("a", "b"), ColumnName::new(["a", "b"])); + assert_eq!(joined_column_name!("a", "b"), ColumnName::new(["a", "b"])); + + assert_eq!( + joined_column_name!(simple, "b"), + ColumnName::new(["x", "b"]) + ); + assert_eq!( + joined_column_name!(nested, "b"), + ColumnName::new(["x.y", "b"]) + ); + + assert_eq!( + joined_column_name!("a", &simple), + ColumnName::new(["a", "x"]) + ); + assert_eq!( + joined_column_name!("a", &nested), + ColumnName::new(["a", "x.y"]) + ); + } + + #[test] + fn test_column_name_methods() { + let simple = column_name!("x"); + let nested = column_name!("x.y"); + + // path() + assert_eq!(simple.path(), "x"); + assert_eq!(nested.path(), "x.y"); + + // into_inner() + assert_eq!(simple.clone().into_inner(), "x"); + assert_eq!(nested.clone().into_inner(), "x.y"); + + // impl Deref + let name: &str = &nested; + assert_eq!(name, "x.y"); + + // impl> FromIterator + let name: ColumnName = ["x", "y"].into_iter().collect(); + assert_eq!(name, nested); + + // impl FromIterator + let name: ColumnName = [nested, simple].into_iter().collect(); + assert_eq!(name, column_name!("x.y.x")); + } +} diff --git a/kernel/src/expressions/mod.rs b/kernel/src/expressions/mod.rs index 8af25b518..54b90ff5f 100644 --- a/kernel/src/expressions/mod.rs +++ b/kernel/src/expressions/mod.rs @@ -5,9 +5,13 @@ use std::fmt::{Display, Formatter}; use itertools::Itertools; +pub use self::column_names::{ + column_expr, column_name, joined_column_expr, joined_column_name, ColumnName, +}; pub use self::scalars::{ArrayData, Scalar, StructData}; use crate::DataType; +mod column_names; mod scalars; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -133,7 +137,7 @@ pub enum Expression { /// A literal value. Literal(Scalar), /// A column reference by name. 
- Column(String), + Column(ColumnName), /// A struct computed from a Vec of expressions Struct(Vec), /// A unary operation. @@ -167,11 +171,17 @@ impl> From for Expression { } } +impl From for Expression { + fn from(value: ColumnName) -> Self { + Self::Column(value) + } +} + impl Display for Expression { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::Literal(l) => write!(f, "{}", l), - Self::Column(name) => write!(f, "Column({})", name), + Self::Literal(l) => write!(f, "{l}"), + Self::Column(name) => write!(f, "Column({name})"), Self::Struct(exprs) => write!( f, "Struct({})", @@ -181,11 +191,11 @@ impl Display for Expression { op: BinaryOperator::Distinct, left, right, - } => write!(f, "DISTINCT({}, {})", left, right), - Self::BinaryOperation { op, left, right } => write!(f, "{} {} {}", left, op, right), + } => write!(f, "DISTINCT({left}, {right})"), + Self::BinaryOperation { op, left, right } => write!(f, "{left} {op} {right}"), Self::UnaryOperation { op, expr } => match op { - UnaryOperator::Not => write!(f, "NOT {}", expr), - UnaryOperator::IsNull => write!(f, "{} IS NULL", expr), + UnaryOperator::Not => write!(f, "NOT {expr}"), + UnaryOperator::IsNull => write!(f, "{expr} IS NULL"), }, Self::VariadicOperation { op, exprs } => match op { VariadicOperator::And => { @@ -209,23 +219,18 @@ impl Display for Expression { impl Expression { /// Returns a set of columns referenced by this expression. - pub fn references(&self) -> HashSet<&str> { + pub fn references(&self) -> HashSet<&ColumnName> { let mut set = HashSet::new(); for expr in self.walk() { if let Self::Column(name) = expr { - set.insert(name.as_str()); + set.insert(name); } } set } - /// Create an new expression for a column reference - pub fn column(name: impl ToString) -> Self { - Self::Column(name.to_string()) - } - /// Create a new expression for a literal value pub fn literal(value: impl Into) -> Self { Self::Literal(value.into()) @@ -410,11 +415,11 @@ impl> std::ops::Div for Expression { #[cfg(test)] mod tests { - use super::Expression as Expr; + use super::{column_expr, Expression as Expr}; #[test] fn test_expression_format() { - let col_ref = Expr::column("x"); + let col_ref = column_expr!("x"); let cases = [ (col_ref.clone(), "Column(x)"), (col_ref.clone().eq(2), "Column(x) = 2"), diff --git a/kernel/src/expressions/scalars.rs b/kernel/src/expressions/scalars.rs index 8c934aa3a..3578258e9 100644 --- a/kernel/src/expressions/scalars.rs +++ b/kernel/src/expressions/scalars.rs @@ -463,7 +463,7 @@ impl PrimitiveType { mod tests { use std::f32::consts::PI; - use crate::expressions::BinaryOperator; + use crate::expressions::{column_expr, BinaryOperator}; use crate::Expression; use super::*; @@ -555,7 +555,7 @@ mod tests { elements: vec![Scalar::Integer(1), Scalar::Integer(2), Scalar::Integer(3)], }); - let column = Expression::column("item"); + let column = column_expr!("item"); let array_op = Expression::binary(BinaryOperator::In, 10, array.clone()); let array_not_op = Expression::binary(BinaryOperator::NotIn, 10, array); let column_op = Expression::binary(BinaryOperator::In, PI, column.clone()); diff --git a/kernel/src/scan/data_skipping.rs b/kernel/src/scan/data_skipping.rs index 7b6079081..31aaaab15 100644 --- a/kernel/src/scan/data_skipping.rs +++ b/kernel/src/scan/data_skipping.rs @@ -8,7 +8,8 @@ use crate::actions::get_log_add_schema; use crate::actions::visitors::SelectionVectorVisitor; use crate::error::DeltaResult; use crate::expressions::{ - BinaryOperator, Expression as Expr, 
ExpressionRef, UnaryOperator, VariadicOperator, + column_expr, column_name, joined_column_expr, BinaryOperator, Expression as Expr, + ExpressionRef, UnaryOperator, VariadicOperator, }; use crate::schema::{DataType, PrimitiveType, SchemaRef, SchemaTransform, StructField, StructType}; use crate::{Engine, EngineData, ExpressionEvaluator, JsonHandler}; @@ -16,10 +17,10 @@ use crate::{Engine, EngineData, ExpressionEvaluator, JsonHandler}; /// Get the expression that checks if a col could be null, assuming tight_bounds = true. In this /// case a column can contain null if any value > 0 is in the nullCount. This is further complicated /// by the default for tightBounds being true, so we have to check if it's EITHER `null` OR `true` -fn get_tight_null_expr(null_col: String) -> Expr { +fn get_tight_null_expr(null_col: Expr) -> Expr { Expr::and( - Expr::distinct(Expr::column("tightBounds"), false), - Expr::gt(Expr::column(null_col), 0i64), + Expr::distinct(column_expr!("tightBounds"), false), + Expr::gt(null_col, 0i64), ) } @@ -27,20 +28,20 @@ fn get_tight_null_expr(null_col: String) -> Expr { /// case, we can only check if the WHOLE column is null, by checking if the number of records is /// equal to the null count, since all other values of nullCount must be ignored (except 0, which /// doesn't help us) -fn get_wide_null_expr(null_col: String) -> Expr { +fn get_wide_null_expr(null_col: Expr) -> Expr { Expr::and( - Expr::eq(Expr::column("tightBounds"), false), - Expr::eq(Expr::column("numRecords"), Expr::column(null_col)), + Expr::eq(column_expr!("tightBounds"), false), + Expr::eq(column_expr!("numRecords"), null_col), ) } /// Get the expression that checks if a col could NOT be null, assuming tight_bounds = true. In this /// case a column has a NOT NULL record if nullCount < numRecords. This is further complicated by /// the default for tightBounds being true, so we have to check if it's EITHER `null` OR `true` -fn get_tight_not_null_expr(null_col: String) -> Expr { +fn get_tight_not_null_expr(null_col: Expr) -> Expr { Expr::and( - Expr::distinct(Expr::column("tightBounds"), false), - Expr::lt(Expr::column(null_col), Expr::column("numRecords")), + Expr::distinct(column_expr!("tightBounds"), false), + Expr::lt(null_col, column_expr!("numRecords")), ) } @@ -48,10 +49,10 @@ fn get_tight_not_null_expr(null_col: String) -> Expr { /// this case, we can only check if the WHOLE column null, by checking if the nullCount == /// numRecords. 
So by inverting that check and seeing if nullCount != numRecords, we can check if /// there is a possibility of a NOT null -fn get_wide_not_null_expr(null_col: String) -> Expr { +fn get_wide_not_null_expr(null_col: Expr) -> Expr { Expr::and( - Expr::eq(Expr::column("tightBounds"), false), - Expr::ne(Expr::column("numRecords"), Expr::column(null_col)), + Expr::eq(column_expr!("tightBounds"), false), + Expr::ne(column_expr!("numRecords"), null_col), ) } @@ -65,7 +66,7 @@ fn as_inverted_data_skipping_predicate(expr: &Expr) -> Option { // to check if a column could NOT have a null, we need two different checks, to see // if the bounds are tight and then to actually do the check if let Column(col) = expr.as_ref() { - let null_col = format!("nullCount.{col}"); + let null_col = joined_column_expr!("nullCount", col); Some(Expr::or( get_tight_not_null_expr(null_col.clone()), get_wide_not_null_expr(null_col), @@ -117,8 +118,8 @@ fn as_data_skipping_predicate(expr: &Expr) -> Option { _ => return None, // unsupported combination of operands }; let stats_col = match op { - LessThan | LessThanOrEqual => "minValues", - GreaterThan | GreaterThanOrEqual => "maxValues", + LessThan | LessThanOrEqual => column_name!("minValues"), + GreaterThan | GreaterThanOrEqual => column_name!("maxValues"), Equal => { return as_data_skipping_predicate(&Expr::and( Expr::le(Column(col.clone()), Literal(val.clone())), @@ -127,14 +128,13 @@ fn as_data_skipping_predicate(expr: &Expr) -> Option { } NotEqual => { return Some(Expr::or( - Expr::gt(Column(format!("minValues.{}", col)), val.clone()), - Expr::lt(Column(format!("maxValues.{}", col)), val.clone()), + Expr::gt(joined_column_expr!("minValues", col), val.clone()), + Expr::lt(joined_column_expr!("maxValues", col), val.clone()), )); } _ => return None, // unsupported operation }; - let col = format!("{}.{}", stats_col, col); - Some(Expr::binary(op, Column(col), val.clone())) + Some(Expr::binary(op, stats_col.join(col), val.clone())) } // push down Not by inverting everything below it UnaryOperation { op: Not, expr } => as_inverted_data_skipping_predicate(expr), @@ -142,7 +142,7 @@ fn as_data_skipping_predicate(expr: &Expr) -> Option { // to check if a column could have a null, we need two different checks, to see if // the bounds are tight and then to actually do the check if let Column(col) = expr.as_ref() { - let null_col = format!("nullCount.{col}"); + let null_col = joined_column_expr!("nullCount", col); Some(Expr::or( get_tight_null_expr(null_col.clone()), get_wide_null_expr(null_col), @@ -185,9 +185,9 @@ impl DataSkippingFilter { static PREDICATE_SCHEMA: LazyLock = LazyLock::new(|| { DataType::struct_type([StructField::new("predicate", DataType::BOOLEAN, true)]) }); - static STATS_EXPR: LazyLock = LazyLock::new(|| Expr::column("add.stats")); + static STATS_EXPR: LazyLock = LazyLock::new(|| column_expr!("add.stats")); static FILTER_EXPR: LazyLock = - LazyLock::new(|| Expr::column("predicate").distinct(false)); + LazyLock::new(|| column_expr!("predicate").distinct(false)); let predicate = predicate.as_deref()?; debug!("Creating a data skipping filter for {}", &predicate); @@ -197,7 +197,7 @@ impl DataSkippingFilter { // extracting the corresponding field from the table schema, and inserting that field. 
let data_fields: Vec<_> = table_schema .fields() - .filter(|field| field_names.contains(&field.name.as_str())) + .filter(|field| field_names.contains(&field.name)) .cloned() .collect(); if data_fields.is_empty() { @@ -308,10 +308,10 @@ mod tests { #[test] fn test_rewrite_basic_comparison() { - let column = Expr::column("a"); + let column = column_expr!("a"); let lit_int = Expr::literal(1_i32); - let min_col = Expr::column("minValues.a"); - let max_col = Expr::column("maxValues.a"); + let min_col = column_expr!("minValues.a"); + let max_col = column_expr!("maxValues.a"); let cases = [ ( diff --git a/kernel/src/scan/log_replay.rs b/kernel/src/scan/log_replay.rs index 3c52e5e2c..f872a8eca 100644 --- a/kernel/src/scan/log_replay.rs +++ b/kernel/src/scan/log_replay.rs @@ -9,7 +9,7 @@ use super::ScanData; use crate::actions::{get_log_schema, ADD_NAME, REMOVE_NAME}; use crate::actions::{visitors::AddVisitor, visitors::RemoveVisitor, Add, Remove}; use crate::engine_data::{GetData, TypedGetData}; -use crate::expressions::{Expression, ExpressionRef}; +use crate::expressions::{column_expr, Expression, ExpressionRef}; use crate::schema::{DataType, MapType, SchemaRef, StructField, StructType}; use crate::{DataVisitor, DeltaResult, Engine, EngineData, ExpressionHandler}; @@ -125,12 +125,12 @@ impl LogReplayScanner { fn get_add_transform_expr(&self) -> Expression { Expression::Struct(vec![ - Expression::column("add.path"), - Expression::column("add.size"), - Expression::column("add.modificationTime"), - Expression::column("add.stats"), - Expression::column("add.deletionVector"), - Expression::Struct(vec![Expression::column("add.partitionValues")]), + column_expr!("add.path"), + column_expr!("add.size"), + column_expr!("add.modificationTime"), + column_expr!("add.stats"), + column_expr!("add.deletionVector"), + Expression::Struct(vec![column_expr!("add.partitionValues")]), ]) } diff --git a/kernel/src/scan/mod.rs b/kernel/src/scan/mod.rs index ca9887ead..b5c01e0a4 100644 --- a/kernel/src/scan/mod.rs +++ b/kernel/src/scan/mod.rs @@ -9,7 +9,7 @@ use url::Url; use crate::actions::deletion_vector::{split_vector, treemap_to_bools, DeletionVectorDescriptor}; use crate::actions::{get_log_add_schema, get_log_schema, ADD_NAME, REMOVE_NAME}; -use crate::expressions::{Expression, ExpressionRef, Scalar}; +use crate::expressions::{ColumnName, Expression, ExpressionRef, Scalar}; use crate::features::ColumnMappingMode; use crate::scan::state::{DvInfo, Stats}; use crate::schema::{DataType, Schema, SchemaRef, StructField, StructType}; @@ -160,7 +160,7 @@ impl ScanResult { /// to materialize the partition column. pub enum ColumnType { // A column, selected from the data, as is - Selected(String), + Selected(ColumnName), // A partition column that needs to be added back in Partition(usize), } @@ -421,7 +421,8 @@ fn get_state_info( debug!("\n\n{logical_field:#?}\nAfter mapping: {physical_field:#?}\n\n"); let physical_name = physical_field.name.clone(); read_fields.push(physical_field); - Ok(ColumnType::Selected(physical_name)) + // TODO: Support nested columns! 
+ Ok(ColumnType::Selected(ColumnName::new([physical_name]))) } }) .try_collect()?; @@ -492,7 +493,7 @@ fn transform_to_logical_internal( )?; Ok::(value_expression.into()) } - ColumnType::Selected(field_name) => Ok(Expression::column(field_name)), + ColumnType::Selected(field_name) => Ok(field_name.clone().into()), }) .try_collect()?; let read_expression = Expression::Struct(all_fields); @@ -614,6 +615,7 @@ mod tests { use std::path::PathBuf; use crate::engine::sync::SyncEngine; + use crate::expressions::column_expr; use crate::schema::PrimitiveType; use crate::Table; @@ -757,7 +759,7 @@ mod tests { assert_eq!(data.len(), 1); // Ineffective predicate pushdown attempted, so the one data file should be returned. - let int_col = Expression::column("numeric.ints.int32"); + let int_col = column_expr!("numeric.ints.int32"); let value = Expression::literal(1000i32); let predicate = Arc::new(int_col.clone().gt(value.clone())); let scan = snapshot diff --git a/kernel/src/snapshot.rs b/kernel/src/snapshot.rs index 859ee8921..f119571e2 100644 --- a/kernel/src/snapshot.rs +++ b/kernel/src/snapshot.rs @@ -11,6 +11,7 @@ use tracing::{debug, warn}; use url::Url; use crate::actions::{get_log_schema, Metadata, Protocol, METADATA_NAME, PROTOCOL_NAME}; +use crate::expressions::column_expr; use crate::features::{ColumnMappingMode, COLUMN_MAPPING_MODE_KEY}; use crate::path::ParsedLogPath; use crate::scan::ScanBuilder; @@ -111,8 +112,8 @@ impl LogSegment { use Expression as Expr; static META_PREDICATE: LazyLock> = LazyLock::new(|| { Some(Arc::new(Expr::or( - Expr::column("metaData.id").is_not_null(), - Expr::column("protocol.minReaderVersion").is_not_null(), + column_expr!("metaData.id").is_not_null(), + column_expr!("protocol.minReaderVersion").is_not_null(), ))) }); // read the same protocol and metadata schema for both commits and checkpoints diff --git a/kernel/tests/read.rs b/kernel/tests/read.rs index f99c9147d..92bf70314 100644 --- a/kernel/tests/read.rs +++ b/kernel/tests/read.rs @@ -13,7 +13,7 @@ use delta_kernel::actions::deletion_vector::split_vector; use delta_kernel::engine::arrow_data::ArrowEngineData; use delta_kernel::engine::default::executor::tokio::TokioBackgroundExecutor; use delta_kernel::engine::default::DefaultEngine; -use delta_kernel::expressions::{BinaryOperator, Expression}; +use delta_kernel::expressions::{column_expr, BinaryOperator, Expression}; use delta_kernel::scan::state::{visit_scan_files, DvInfo, Stats}; use delta_kernel::scan::{transform_to_logical, Scan}; use delta_kernel::schema::Schema; @@ -339,7 +339,7 @@ async fn stats() -> Result<(), Box> { (NotEqual, 8, vec![&batch2, &batch1]), ]; for (op, value, expected_batches) in test_cases { - let predicate = Expression::binary(op, Expression::column("id"), value); + let predicate = Expression::binary(op, column_expr!("id"), value); let scan = snapshot .clone() .scan_builder() @@ -655,27 +655,24 @@ fn table_for_numbers(nums: Vec) -> Vec { fn predicate_on_number() -> Result<(), Box> { let cases = vec![ ( - Expression::column("number").lt(4i64), + column_expr!("number").lt(4i64), table_for_numbers(vec![1, 2, 3]), ), ( - Expression::column("number").le(4i64), + column_expr!("number").le(4i64), table_for_numbers(vec![1, 2, 3, 4]), ), ( - Expression::column("number").gt(4i64), + column_expr!("number").gt(4i64), table_for_numbers(vec![5, 6]), ), ( - Expression::column("number").ge(4i64), + column_expr!("number").ge(4i64), table_for_numbers(vec![4, 5, 6]), ), + (column_expr!("number").eq(4i64), table_for_numbers(vec![4])), ( - 
Expression::column("number").eq(4i64), - table_for_numbers(vec![4]), - ), - ( - Expression::column("number").ne(4i64), + column_expr!("number").ne(4i64), table_for_numbers(vec![1, 2, 3, 5, 6]), ), ]; @@ -695,27 +692,27 @@ fn predicate_on_number() -> Result<(), Box> { fn predicate_on_number_not() -> Result<(), Box> { let cases = vec![ ( - Expression::not(Expression::column("number").lt(4i64)), + Expression::not(column_expr!("number").lt(4i64)), table_for_numbers(vec![4, 5, 6]), ), ( - Expression::not(Expression::column("number").le(4i64)), + Expression::not(column_expr!("number").le(4i64)), table_for_numbers(vec![5, 6]), ), ( - Expression::not(Expression::column("number").gt(4i64)), + Expression::not(column_expr!("number").gt(4i64)), table_for_numbers(vec![1, 2, 3, 4]), ), ( - Expression::not(Expression::column("number").ge(4i64)), + Expression::not(column_expr!("number").ge(4i64)), table_for_numbers(vec![1, 2, 3]), ), ( - Expression::not(Expression::column("number").eq(4i64)), + Expression::not(column_expr!("number").eq(4i64)), table_for_numbers(vec![1, 2, 3, 5, 6]), ), ( - Expression::not(Expression::column("number").ne(4i64)), + Expression::not(column_expr!("number").ne(4i64)), table_for_numbers(vec![4]), ), ]; @@ -744,8 +741,8 @@ fn predicate_on_number_with_not_null() -> Result<(), Box> "./tests/data/basic_partitioned", Some(&["a_float", "number"]), Some(Expression::and( - Expression::column("number").is_not_null(), - Expression::column("number").lt(Expression::literal(3i64)), + column_expr!("number").is_not_null(), + column_expr!("number").lt(Expression::literal(3i64)), )), expected, )?; @@ -758,7 +755,7 @@ fn predicate_null() -> Result<(), Box> { read_table_data_str( "./tests/data/basic_partitioned", Some(&["a_float", "number"]), - Some(Expression::column("number").is_null()), + Some(column_expr!("number").is_null()), expected, )?; Ok(()) @@ -785,7 +782,7 @@ fn mixed_null() -> Result<(), Box> { read_table_data_str( "./tests/data/mixed-nulls", Some(&["part", "n"]), - Some(Expression::column("n").is_null()), + Some(column_expr!("n").is_null()), expected, )?; Ok(()) @@ -812,7 +809,7 @@ fn mixed_not_null() -> Result<(), Box> { read_table_data_str( "./tests/data/mixed-nulls", Some(&["part", "n"]), - Some(Expression::column("n").is_not_null()), + Some(column_expr!("n").is_not_null()), expected, )?; Ok(()) @@ -822,27 +819,27 @@ fn mixed_not_null() -> Result<(), Box> { fn and_or_predicates() -> Result<(), Box> { let cases = vec![ ( - Expression::column("number") + column_expr!("number") .gt(4i64) - .and(Expression::column("a_float").gt(5.5)), + .and(column_expr!("a_float").gt(5.5)), table_for_numbers(vec![6]), ), ( - Expression::column("number") + column_expr!("number") .gt(4i64) - .and(Expression::not(Expression::column("a_float").gt(5.5))), + .and(Expression::not(column_expr!("a_float").gt(5.5))), table_for_numbers(vec![5]), ), ( - Expression::column("number") + column_expr!("number") .gt(4i64) - .or(Expression::column("a_float").gt(5.5)), + .or(column_expr!("a_float").gt(5.5)), table_for_numbers(vec![5, 6]), ), ( - Expression::column("number") + column_expr!("number") .gt(4i64) - .or(Expression::not(Expression::column("a_float").gt(5.5))), + .or(Expression::not(column_expr!("a_float").gt(5.5))), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ]; @@ -862,33 +859,33 @@ fn not_and_or_predicates() -> Result<(), Box> { let cases = vec![ ( Expression::not( - Expression::column("number") + column_expr!("number") .gt(4i64) - .and(Expression::column("a_float").gt(5.5)), + 
.and(column_expr!("a_float").gt(5.5)), ), table_for_numbers(vec![1, 2, 3, 4, 5]), ), ( Expression::not( - Expression::column("number") + column_expr!("number") .gt(4i64) - .and(Expression::not(Expression::column("a_float").gt(5.5))), + .and(Expression::not(column_expr!("a_float").gt(5.5))), ), table_for_numbers(vec![1, 2, 3, 4, 6]), ), ( Expression::not( - Expression::column("number") + column_expr!("number") .gt(4i64) - .or(Expression::column("a_float").gt(5.5)), + .or(column_expr!("a_float").gt(5.5)), ), table_for_numbers(vec![1, 2, 3, 4]), ), ( Expression::not( - Expression::column("number") + column_expr!("number") .gt(4i64) - .or(Expression::not(Expression::column("a_float").gt(5.5))), + .or(Expression::not(column_expr!("a_float").gt(5.5))), ), vec![], ), @@ -913,23 +910,23 @@ fn invalid_skips_none_predicates() -> Result<(), Box> { table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::column("number").distinct(3i64), + column_expr!("number").distinct(3i64), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::column("number").gt(empty_struct.clone()), + column_expr!("number").gt(empty_struct.clone()), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::column("number").and(empty_struct.clone().is_null()), + column_expr!("number").and(empty_struct.clone().is_null()), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::not(Expression::column("number").gt(empty_struct.clone())), + Expression::not(column_expr!("number").gt(empty_struct.clone())), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::not(Expression::column("number").and(empty_struct.clone().is_null())), + Expression::not(column_expr!("number").and(empty_struct.clone().is_null())), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ]; @@ -963,7 +960,7 @@ fn with_predicate_and_removes() -> Result<(), Box> { read_table_data_str( "./tests/data/table-with-dv-small/", None, - Some(Expression::gt(Expression::column("value"), 3)), + Some(Expression::gt(column_expr!("value"), 3)), expected, )?; Ok(())
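For reviewers, a rough sketch of how the new API reads in calling code. This is illustrative only (not part of the patch); it uses only items introduced above (`ColumnName::new`, the `column_name!` and `joined_column_name!` macros, and `column_expr!`) and assumes the current dot-joining semantics described in the commit message, mirroring the doctests and unit tests in `column_names.rs`:

```rust
use delta_kernel::expressions::{column_expr, column_name, joined_column_name, ColumnName};

fn main() {
    // Today a ColumnName is still a single dot-joined string, so these are equivalent:
    let built = ColumnName::new(["numeric", "ints", "int32"]);
    let parsed = column_name!("numeric.ints.int32");
    assert_eq!(built, parsed);

    // Joining names, e.g. prefixing a column with a stats section as data skipping does:
    let min = joined_column_name!("minValues", "numeric.ints.int32");
    assert_eq!(min, column_name!("minValues").join(&parsed));

    // Expression::Column now wraps a ColumnName instead of a String:
    let _predicate = column_expr!("numeric.ints.int32").gt(1000i64);
}
```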