diff --git a/derive-macros/src/lib.rs b/derive-macros/src/lib.rs index 9bf74690b..3b2e35aa4 100644 --- a/derive-macros/src/lib.rs +++ b/derive-macros/src/lib.rs @@ -5,6 +5,26 @@ use syn::{ parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, Meta, PathArguments, Type, }; +/// Parses a dot-delimited column name into an array of field names. See +/// [`delta_kernel::expressions::column_name::column_name`] macro for details. +#[proc_macro] +pub fn parse_column_name(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let is_valid = |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.'; + let err = match syn::parse(input) { + Ok(syn::Lit::Str(name)) => match name.value().chars().find(|c| !is_valid(*c)) { + Some(bad_char) => Error::new(name.span(), format!("Invalid character: {bad_char:?}")), + _ => { + let path = name.value(); + let path = path.split('.').map(proc_macro2::Literal::string); + return quote_spanned! { name.span() => [#(#path),*] }.into(); + } + }, + Ok(lit) => Error::new(lit.span(), "Expected a string literal"), + Err(err) => err, + }; + err.into_compile_error().into() +} + /// Derive a `delta_kernel::schemas::ToDataType` implementation for the annotated struct. 
The actual /// field names in the schema (and therefore of the struct members) are all mandated by the Delta /// spec, and so the user of this macro is responsible for ensuring that diff --git a/ffi/src/expressions.rs b/ffi/src/expressions.rs index 087cae163..19c334038 100644 --- a/ffi/src/expressions.rs +++ b/ffi/src/expressions.rs @@ -5,7 +5,7 @@ use crate::{ ReferenceSet, TryFromStringSlice, }; use delta_kernel::{ - expressions::{BinaryOperator, Expression, UnaryOperator}, + expressions::{BinaryOperator, ColumnName, Expression, UnaryOperator}, DeltaResult, }; @@ -146,7 +146,9 @@ fn visit_expression_column_impl( state: &mut KernelExpressionVisitorState, name: DeltaResult, ) -> DeltaResult { - Ok(wrap_expression(state, Expression::Column(name?))) + // TODO: FIXME: This is incorrect if any field name in the column path contains a period. + let name = ColumnName::new(name?.split('.')).into(); + Ok(wrap_expression(state, name)) } #[no_mangle] diff --git a/kernel/src/actions/set_transaction.rs b/kernel/src/actions/set_transaction.rs index 5dcba4323..5cfae8863 100644 --- a/kernel/src/actions/set_transaction.rs +++ b/kernel/src/actions/set_transaction.rs @@ -2,8 +2,9 @@ use std::sync::{Arc, LazyLock}; use crate::actions::visitors::SetTransactionVisitor; use crate::actions::{get_log_schema, SetTransaction, SET_TRANSACTION_NAME}; +use crate::expressions::column_expr; use crate::snapshot::Snapshot; -use crate::{DeltaResult, Engine, EngineData, Expression, ExpressionRef, SchemaRef}; +use crate::{DeltaResult, Engine, EngineData, ExpressionRef, SchemaRef}; pub use crate::actions::visitors::SetTransactionMap; pub struct SetTransactionScanner { @@ -53,7 +54,7 @@ impl SetTransactionScanner { // point filtering by a particular app id, even if we have one, because app ids are all in // the a single checkpoint part having large min/max range (because they're usually uuids). 
static META_PREDICATE: LazyLock> = - LazyLock::new(|| Some(Arc::new(Expression::column("txn.appId").is_not_null()))); + LazyLock::new(|| Some(Arc::new(column_expr!("txn.appId").is_not_null()))); self.snapshot .log_segment .replay(engine, schema.clone(), schema, META_PREDICATE.clone()) diff --git a/kernel/src/engine/arrow_expression.rs b/kernel/src/engine/arrow_expression.rs index fa87cb8bb..cb9138299 100644 --- a/kernel/src/engine/arrow_expression.rs +++ b/kernel/src/engine/arrow_expression.rs @@ -589,9 +589,9 @@ mod tests { let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap(); - let not_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item")); + let not_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item")); - let in_op = Expression::binary(BinaryOperator::In, 5, Expression::column("item")); + let in_op = Expression::binary(BinaryOperator::In, 5, column_expr!("item")); let result = evaluate_expression(¬_op, &batch, None).unwrap(); let expected = BooleanArray::from(vec![true, false, true]); @@ -609,7 +609,7 @@ mod tests { let schema = Schema::new([field.clone()]); let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap(); - let in_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item")); + let in_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item")); let in_result = evaluate_expression(&in_op, &batch, None); @@ -654,8 +654,8 @@ mod tests { let in_op = Expression::binary( BinaryOperator::NotIn, - Expression::column("item"), - Expression::column("item"), + column_expr!("item"), + column_expr!("item"), ); let in_result = evaluate_expression(&in_op, &batch, None); @@ -679,10 +679,9 @@ mod tests { let array = ListArray::new(field.clone(), offsets, Arc::new(values), None); let batch = RecordBatch::try_new(Arc::new(schema), 
vec![Arc::new(array.clone())]).unwrap(); - let str_not_op = - Expression::binary(BinaryOperator::NotIn, "bye", Expression::column("item")); + let str_not_op = Expression::binary(BinaryOperator::NotIn, "bye", column_expr!("item")); - let str_in_op = Expression::binary(BinaryOperator::In, "hi", Expression::column("item")); + let str_in_op = Expression::binary(BinaryOperator::In, "hi", column_expr!("item")); let result = evaluate_expression(&str_in_op, &batch, None).unwrap(); let expected = BooleanArray::from(vec![true, true, true]); @@ -699,7 +698,7 @@ mod tests { let values = Int32Array::from(vec![1, 2, 3]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values.clone())]).unwrap(); - let column = Expression::column("a"); + let column = column_expr!("a"); let results = evaluate_expression(&column, &batch, None).unwrap(); assert_eq!(results.as_ref(), &values); @@ -720,7 +719,7 @@ mod tests { vec![Arc::new(struct_array.clone())], ) .unwrap(); - let column = Expression::column("b.a"); + let column = column_expr!("b.a"); let results = evaluate_expression(&column, &batch, None).unwrap(); assert_eq!(results.as_ref(), &values); } @@ -730,7 +729,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let values = Int32Array::from(vec![1, 2, 3]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap(); - let column = Expression::column("a"); + let column = column_expr!("a"); let expression = column.clone().add(1); let results = evaluate_expression(&expression, &batch, None).unwrap(); @@ -766,8 +765,8 @@ mod tests { vec![Arc::new(values.clone()), Arc::new(values)], ) .unwrap(); - let column_a = Expression::column("a"); - let column_b = Expression::column("b"); + let column_a = column_expr!("a"); + let column_b = column_expr!("b"); let expression = column_a.clone().add(column_b.clone()); let results = evaluate_expression(&expression, &batch, None).unwrap(); @@ -790,7 +789,7 @@ mod 
tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let values = Int32Array::from(vec![1, 2, 3]); let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap(); - let column = Expression::column("a"); + let column = column_expr!("a"); let expression = column.clone().lt(2); let results = evaluate_expression(&expression, &batch, None).unwrap(); @@ -837,8 +836,8 @@ mod tests { ], ) .unwrap(); - let column_a = Expression::column("a"); - let column_b = Expression::column("b"); + let column_a = column_expr!("a"); + let column_b = column_expr!("b"); let expression = column_a.clone().and(column_b.clone()); let results = diff --git a/kernel/src/engine/parquet_row_group_skipping/tests.rs b/kernel/src/engine/parquet_row_group_skipping/tests.rs index 6f5dd3a48..19bd2b5bf 100644 --- a/kernel/src/engine/parquet_row_group_skipping/tests.rs +++ b/kernel/src/engine/parquet_row_group_skipping/tests.rs @@ -1,4 +1,5 @@ use super::*; +use crate::expressions::column_expr; use crate::Expression; use parquet::arrow::arrow_reader::ArrowReaderMetadata; use std::fs::File; @@ -39,21 +40,21 @@ fn test_get_stat_values() { // The expression doesn't matter -- it just needs to mention all the columns we care about. 
let columns = Expression::and_from(vec![ - Expression::column("varlen.utf8"), - Expression::column("numeric.ints.int64"), - Expression::column("numeric.ints.int32"), - Expression::column("numeric.ints.int16"), - Expression::column("numeric.ints.int8"), - Expression::column("numeric.floats.float32"), - Expression::column("numeric.floats.float64"), - Expression::column("bool"), - Expression::column("varlen.binary"), - Expression::column("numeric.decimals.decimal32"), - Expression::column("numeric.decimals.decimal64"), - Expression::column("numeric.decimals.decimal128"), - Expression::column("chrono.date32"), - Expression::column("chrono.timestamp"), - Expression::column("chrono.timestamp_ntz"), + column_expr!("varlen.utf8"), + column_expr!("numeric.ints.int64"), + column_expr!("numeric.ints.int32"), + column_expr!("numeric.ints.int16"), + column_expr!("numeric.ints.int8"), + column_expr!("numeric.floats.float32"), + column_expr!("numeric.floats.float64"), + column_expr!("bool"), + column_expr!("varlen.binary"), + column_expr!("numeric.decimals.decimal32"), + column_expr!("numeric.decimals.decimal64"), + column_expr!("numeric.decimals.decimal128"), + column_expr!("chrono.date32"), + column_expr!("chrono.timestamp"), + column_expr!("chrono.timestamp_ntz"), ]); let filter = RowGroupFilter::new(metadata.metadata().row_group(0), &columns); diff --git a/kernel/src/engine/parquet_stats_skipping/tests.rs b/kernel/src/engine/parquet_stats_skipping/tests.rs index a95ac4102..b4bdd97d3 100644 --- a/kernel/src/engine/parquet_stats_skipping/tests.rs +++ b/kernel/src/engine/parquet_stats_skipping/tests.rs @@ -1,5 +1,5 @@ use super::*; -use crate::expressions::{ArrayData, StructData}; +use crate::expressions::{column_expr, ArrayData, StructData}; use crate::schema::ArrayType; use crate::DataType; @@ -337,7 +337,7 @@ fn test_binary_eq_ne() { const LO: Scalar = Scalar::Long(1); const MID: Scalar = Scalar::Long(10); const HI: Scalar = Scalar::Long(100); - let col = 
&Expression::column("x"); + let col = &column_expr!("x"); for inverted in [false, true] { // negative test -- mismatched column type @@ -485,7 +485,7 @@ fn test_binary_lt_ge() { const LO: Scalar = Scalar::Long(1); const MID: Scalar = Scalar::Long(10); const HI: Scalar = Scalar::Long(100); - let col = &Expression::column("x"); + let col = &column_expr!("x"); for inverted in [false, true] { expect_eq!( @@ -585,7 +585,7 @@ fn test_binary_le_gt() { const LO: Scalar = Scalar::Long(1); const MID: Scalar = Scalar::Long(10); const HI: Scalar = Scalar::Long(100); - let col = &Expression::column("x"); + let col = &column_expr!("x"); for inverted in [false, true] { // negative test -- mismatched column type @@ -736,7 +736,7 @@ impl ParquetStatsSkippingFilter for NullCountTestFilter { fn test_not_null() { use UnaryOperator::IsNull; - let col = &Expression::column("x"); + let col = &column_expr!("x"); for inverted in [false, true] { expect_eq!( NullCountTestFilter::new(None, 10).apply_unary(IsNull, col, inverted), @@ -809,7 +809,7 @@ impl ParquetStatsSkippingFilter for AllNullTestFilter { #[test] fn test_sql_where() { - let col = &Expression::column("x"); + let col = &column_expr!("x"); let val = &Expression::literal(1); const NULL: Expression = Expression::Literal(Scalar::Null(DataType::BOOLEAN)); const FALSE: Expression = Expression::Literal(Scalar::Boolean(false)); diff --git a/kernel/src/expressions/column_names.rs b/kernel/src/expressions/column_names.rs new file mode 100644 index 000000000..4129abb30 --- /dev/null +++ b/kernel/src/expressions/column_names.rs @@ -0,0 +1,269 @@ +use std::borrow::Borrow; +use std::fmt::{Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::ops::Deref; +/// A (possibly nested) column name. +// TODO: Track name as a path rather than a single string +#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord)] +pub struct ColumnName { + path: String, +} + +impl ColumnName { + /// Constructs a new column name from an iterator of field names. 
The field names are joined + /// together to make a single path. + pub fn new(path: impl IntoIterator>) -> Self { + let path: Vec<_> = path.into_iter().map(Into::into).collect(); + let path = path.join("."); + Self { path } + } + + /// Joins this column with another, concatenating their fields into a single nested column path. + /// + /// NOTE: This is a convenience method that copies two arguments without consuming them. If more + /// arguments are needed, or if performance is a concern, it is recommended to use + /// [`FromIterator for ColumnName`](#impl-FromIterator-for-ColumnName) instead: + /// + /// ``` + /// # use delta_kernel::expressions::ColumnName; + /// let x = ColumnName::new(["a", "b"]); + /// let y = ColumnName::new(["c", "d"]); + /// let joined: ColumnName = [x, y].into_iter().collect(); + /// assert_eq!(joined, ColumnName::new(["a", "b", "c", "d"])); + /// ``` + pub fn join(&self, right: &ColumnName) -> ColumnName { + [self.clone(), right.clone()].into_iter().collect() + } + + /// The path of field names for this column name + pub fn path(&self) -> &String { + &self.path + } + + /// Consumes this column name and returns the path of field names. + pub fn into_inner(self) -> String { + self.path + } +} + +/// Creates a new column name from a path of field names. Each field name is taken as-is, and may +/// contain arbitrary characters (including periods, spaces, etc.). +impl> FromIterator for ColumnName { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, + { + Self::new(iter) + } +} + +/// Creates a new column name by joining multiple column names together. 
+impl FromIterator for ColumnName { + fn from_iter(iter: T) -> Self + where + T: IntoIterator, + { + Self::new(iter.into_iter().map(ColumnName::into_inner)) + } +} + +impl Display for ColumnName { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + (**self).fmt(f) + } +} + +impl Deref for ColumnName { + type Target = String; + + fn deref(&self) -> &String { + &self.path + } +} + +// Allows searching collections of `ColumnName` without an owned key value +impl Borrow for ColumnName { + fn borrow(&self) -> &String { + self + } +} + +// Allows searching collections of `&ColumnName` without an owned key value. Needed because there is +// apparently no blanket `impl Borrow for &T where T: Borrow`, even tho `Eq` [1] and +// `Hash` [2] both have blanket impl for treating `&T` like `T`. +// +// [1] https://doc.rust-lang.org/std/cmp/trait.Eq.html#impl-Eq-for-%26A +// [2] https://doc.rust-lang.org/std/hash/trait.Hash.html#impl-Hash-for-%26T +impl Borrow for &ColumnName { + fn borrow(&self) -> &String { + self + } +} + +impl Hash for ColumnName { + fn hash(&self, hasher: &mut H) { + (**self).hash(hasher) + } +} + +/// Creates a nested column name whose field names are all simple column names (containing only +/// alphanumeric characters and underscores), delimited by dots. This macro is provided as a +/// convenience for the common case where the caller knows the column name contains only simple +/// field names and that splitting by periods is safe: +/// +/// ``` +/// # use delta_kernel::expressions::{column_name, ColumnName}; +/// assert_eq!(column_name!("a.b.c"), ColumnName::new(["a", "b", "c"])); +/// ``` +/// +/// To avoid accidental misuse, the argument must be a string literal, so the compiler can validate +/// the safety conditions. 
Thus, the following uses would fail to compile: +/// +/// ```compile_fail +/// # use delta_kernel::expressions::column_name; +/// let s = "a.b"; +/// let name = column_name!(s); // not a string literal +/// ``` +/// +/// ```compile_fail +/// # use delta_kernel::expressions::column_name; +/// let name = column_name!("a b"); // non-alphanumeric character +/// ``` +// NOTE: Macros are only public if exported, which defines them at the root of the crate. But we +// don't want it there. So, we export a hidden macro and pub use it here where we actually want it. +#[macro_export] +#[doc(hidden)] +macro_rules! __column_name { + ( $($name:tt)* ) => { + $crate::expressions::ColumnName::new(delta_kernel_derive::parse_column_name!($($name)*)) + }; +} +#[doc(inline)] +pub use __column_name as column_name; + +/// Joins two column names together, when one or both inputs might be literal strings representing +/// simple (non-nested) column names. For example: +/// +/// ``` +/// # use delta_kernel::expressions::{column_name, joined_column_name}; +/// assert_eq!(joined_column_name!("a.b", "c"), column_name!("a.b").join(&column_name!("c"))) +/// ``` +/// +/// To avoid accidental misuse, at least one argument must be a string literal. Thus, the following +/// invocation would fail to compile: +/// +/// ```compile_fail +/// # use delta_kernel::expressions::joined_column_name; +/// let s = "s"; +/// let name = joined_column_name!(s, s); +/// ``` +#[macro_export] +#[doc(hidden)] +macro_rules!
__joined_column_name { + ( $left:literal, $right:literal ) => { + $crate::__column_name!($left).join(&$crate::__column_name!($right)) + }; + ( $left:literal, $right:expr ) => { + $crate::__column_name!($left).join(&$right) + }; + ( $left:expr, $right:literal) => { + $left.join(&$crate::__column_name!($right)) + }; + ( $($other:tt)* ) => { + compile_error!("joined_column_name!() requires at least one string literal input") + }; +} +#[doc(inline)] +pub use __joined_column_name as joined_column_name; + +#[macro_export] +#[doc(hidden)] +macro_rules! __column_expr { + ( $($name:tt)* ) => { + $crate::expressions::Expression::from($crate::__column_name!($($name)*)) + }; +} +#[doc(inline)] +pub use __column_expr as column_expr; + +#[macro_export] +#[doc(hidden)] +macro_rules! __joined_column_expr { + ( $($name:tt)* ) => { + $crate::expressions::Expression::from($crate::__joined_column_name!($($name)*)) + }; +} +#[doc(inline)] +pub use __joined_column_expr as joined_column_expr; + +#[cfg(test)] +mod test { + use super::*; + use delta_kernel_derive::parse_column_name; + + #[test] + fn test_parse_column_name_macros() { + assert_eq!(parse_column_name!("a"), ["a"]); + + assert_eq!(parse_column_name!("a"), ["a"]); + assert_eq!(parse_column_name!("a.b"), ["a", "b"]); + assert_eq!(parse_column_name!("a.b.c"), ["a", "b", "c"]); + } + + #[test] + fn test_column_name_macros() { + let simple = column_name!("x"); + let nested = column_name!("x.y"); + + assert_eq!(column_name!("a"), ColumnName::new(["a"])); + assert_eq!(column_name!("a.b"), ColumnName::new(["a", "b"])); + assert_eq!(column_name!("a.b.c"), ColumnName::new(["a", "b", "c"])); + + assert_eq!(joined_column_name!("a", "b"), ColumnName::new(["a", "b"])); + assert_eq!(joined_column_name!("a", "b"), ColumnName::new(["a", "b"])); + + assert_eq!( + joined_column_name!(simple, "b"), + ColumnName::new(["x", "b"]) + ); + assert_eq!( + joined_column_name!(nested, "b"), + ColumnName::new(["x.y", "b"]) + ); + + assert_eq!( + 
joined_column_name!("a", &simple), + ColumnName::new(["a", "x"]) + ); + assert_eq!( + joined_column_name!("a", &nested), + ColumnName::new(["a", "x.y"]) + ); + } + + #[test] + fn test_column_name_methods() { + let simple = column_name!("x"); + let nested = column_name!("x.y"); + + // path() + assert_eq!(simple.path(), "x"); + assert_eq!(nested.path(), "x.y"); + + // into_inner() + assert_eq!(simple.clone().into_inner(), "x"); + assert_eq!(nested.clone().into_inner(), "x.y"); + + // impl Deref + let name: &str = &nested; + assert_eq!(name, "x.y"); + + // impl> FromIterator + let name: ColumnName = ["x", "y"].into_iter().collect(); + assert_eq!(name, nested); + + // impl FromIterator + let name: ColumnName = [nested, simple].into_iter().collect(); + assert_eq!(name, column_name!("x.y.x")); + } +} diff --git a/kernel/src/expressions/mod.rs b/kernel/src/expressions/mod.rs index 8af25b518..54b90ff5f 100644 --- a/kernel/src/expressions/mod.rs +++ b/kernel/src/expressions/mod.rs @@ -5,9 +5,13 @@ use std::fmt::{Display, Formatter}; use itertools::Itertools; +pub use self::column_names::{ + column_expr, column_name, joined_column_expr, joined_column_name, ColumnName, +}; pub use self::scalars::{ArrayData, Scalar, StructData}; use crate::DataType; +mod column_names; mod scalars; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -133,7 +137,7 @@ pub enum Expression { /// A literal value. Literal(Scalar), /// A column reference by name. - Column(String), + Column(ColumnName), /// A struct computed from a Vec of expressions Struct(Vec), /// A unary operation. 
@@ -167,11 +171,17 @@ impl> From for Expression { } } +impl From for Expression { + fn from(value: ColumnName) -> Self { + Self::Column(value) + } +} + impl Display for Expression { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - Self::Literal(l) => write!(f, "{}", l), - Self::Column(name) => write!(f, "Column({})", name), + Self::Literal(l) => write!(f, "{l}"), + Self::Column(name) => write!(f, "Column({name})"), Self::Struct(exprs) => write!( f, "Struct({})", @@ -181,11 +191,11 @@ impl Display for Expression { op: BinaryOperator::Distinct, left, right, - } => write!(f, "DISTINCT({}, {})", left, right), - Self::BinaryOperation { op, left, right } => write!(f, "{} {} {}", left, op, right), + } => write!(f, "DISTINCT({left}, {right})"), + Self::BinaryOperation { op, left, right } => write!(f, "{left} {op} {right}"), Self::UnaryOperation { op, expr } => match op { - UnaryOperator::Not => write!(f, "NOT {}", expr), - UnaryOperator::IsNull => write!(f, "{} IS NULL", expr), + UnaryOperator::Not => write!(f, "NOT {expr}"), + UnaryOperator::IsNull => write!(f, "{expr} IS NULL"), }, Self::VariadicOperation { op, exprs } => match op { VariadicOperator::And => { @@ -209,23 +219,18 @@ impl Display for Expression { impl Expression { /// Returns a set of columns referenced by this expression. 
- pub fn references(&self) -> HashSet<&str> { + pub fn references(&self) -> HashSet<&ColumnName> { let mut set = HashSet::new(); for expr in self.walk() { if let Self::Column(name) = expr { - set.insert(name.as_str()); + set.insert(name); } } set } - /// Create an new expression for a column reference - pub fn column(name: impl ToString) -> Self { - Self::Column(name.to_string()) - } - /// Create a new expression for a literal value pub fn literal(value: impl Into) -> Self { Self::Literal(value.into()) @@ -410,11 +415,11 @@ impl> std::ops::Div for Expression { #[cfg(test)] mod tests { - use super::Expression as Expr; + use super::{column_expr, Expression as Expr}; #[test] fn test_expression_format() { - let col_ref = Expr::column("x"); + let col_ref = column_expr!("x"); let cases = [ (col_ref.clone(), "Column(x)"), (col_ref.clone().eq(2), "Column(x) = 2"), diff --git a/kernel/src/expressions/scalars.rs b/kernel/src/expressions/scalars.rs index 8c934aa3a..3578258e9 100644 --- a/kernel/src/expressions/scalars.rs +++ b/kernel/src/expressions/scalars.rs @@ -463,7 +463,7 @@ impl PrimitiveType { mod tests { use std::f32::consts::PI; - use crate::expressions::BinaryOperator; + use crate::expressions::{column_expr, BinaryOperator}; use crate::Expression; use super::*; @@ -555,7 +555,7 @@ mod tests { elements: vec![Scalar::Integer(1), Scalar::Integer(2), Scalar::Integer(3)], }); - let column = Expression::column("item"); + let column = column_expr!("item"); let array_op = Expression::binary(BinaryOperator::In, 10, array.clone()); let array_not_op = Expression::binary(BinaryOperator::NotIn, 10, array); let column_op = Expression::binary(BinaryOperator::In, PI, column.clone()); diff --git a/kernel/src/scan/data_skipping.rs b/kernel/src/scan/data_skipping.rs index 7b6079081..31aaaab15 100644 --- a/kernel/src/scan/data_skipping.rs +++ b/kernel/src/scan/data_skipping.rs @@ -8,7 +8,8 @@ use crate::actions::get_log_add_schema; use crate::actions::visitors::SelectionVectorVisitor; 
use crate::error::DeltaResult; use crate::expressions::{ - BinaryOperator, Expression as Expr, ExpressionRef, UnaryOperator, VariadicOperator, + column_expr, column_name, joined_column_expr, BinaryOperator, Expression as Expr, + ExpressionRef, UnaryOperator, VariadicOperator, }; use crate::schema::{DataType, PrimitiveType, SchemaRef, SchemaTransform, StructField, StructType}; use crate::{Engine, EngineData, ExpressionEvaluator, JsonHandler}; @@ -16,10 +17,10 @@ use crate::{Engine, EngineData, ExpressionEvaluator, JsonHandler}; /// Get the expression that checks if a col could be null, assuming tight_bounds = true. In this /// case a column can contain null if any value > 0 is in the nullCount. This is further complicated /// by the default for tightBounds being true, so we have to check if it's EITHER `null` OR `true` -fn get_tight_null_expr(null_col: String) -> Expr { +fn get_tight_null_expr(null_col: Expr) -> Expr { Expr::and( - Expr::distinct(Expr::column("tightBounds"), false), - Expr::gt(Expr::column(null_col), 0i64), + Expr::distinct(column_expr!("tightBounds"), false), + Expr::gt(null_col, 0i64), ) } @@ -27,20 +28,20 @@ fn get_tight_null_expr(null_col: String) -> Expr { /// case, we can only check if the WHOLE column is null, by checking if the number of records is /// equal to the null count, since all other values of nullCount must be ignored (except 0, which /// doesn't help us) -fn get_wide_null_expr(null_col: String) -> Expr { +fn get_wide_null_expr(null_col: Expr) -> Expr { Expr::and( - Expr::eq(Expr::column("tightBounds"), false), - Expr::eq(Expr::column("numRecords"), Expr::column(null_col)), + Expr::eq(column_expr!("tightBounds"), false), + Expr::eq(column_expr!("numRecords"), null_col), ) } /// Get the expression that checks if a col could NOT be null, assuming tight_bounds = true. In this /// case a column has a NOT NULL record if nullCount < numRecords. 
This is further complicated by /// the default for tightBounds being true, so we have to check if it's EITHER `null` OR `true` -fn get_tight_not_null_expr(null_col: String) -> Expr { +fn get_tight_not_null_expr(null_col: Expr) -> Expr { Expr::and( - Expr::distinct(Expr::column("tightBounds"), false), - Expr::lt(Expr::column(null_col), Expr::column("numRecords")), + Expr::distinct(column_expr!("tightBounds"), false), + Expr::lt(null_col, column_expr!("numRecords")), ) } @@ -48,10 +49,10 @@ fn get_tight_not_null_expr(null_col: String) -> Expr { /// this case, we can only check if the WHOLE column null, by checking if the nullCount == /// numRecords. So by inverting that check and seeing if nullCount != numRecords, we can check if /// there is a possibility of a NOT null -fn get_wide_not_null_expr(null_col: String) -> Expr { +fn get_wide_not_null_expr(null_col: Expr) -> Expr { Expr::and( - Expr::eq(Expr::column("tightBounds"), false), - Expr::ne(Expr::column("numRecords"), Expr::column(null_col)), + Expr::eq(column_expr!("tightBounds"), false), + Expr::ne(column_expr!("numRecords"), null_col), ) } @@ -65,7 +66,7 @@ fn as_inverted_data_skipping_predicate(expr: &Expr) -> Option { // to check if a column could NOT have a null, we need two different checks, to see // if the bounds are tight and then to actually do the check if let Column(col) = expr.as_ref() { - let null_col = format!("nullCount.{col}"); + let null_col = joined_column_expr!("nullCount", col); Some(Expr::or( get_tight_not_null_expr(null_col.clone()), get_wide_not_null_expr(null_col), @@ -117,8 +118,8 @@ fn as_data_skipping_predicate(expr: &Expr) -> Option { _ => return None, // unsupported combination of operands }; let stats_col = match op { - LessThan | LessThanOrEqual => "minValues", - GreaterThan | GreaterThanOrEqual => "maxValues", + LessThan | LessThanOrEqual => column_name!("minValues"), + GreaterThan | GreaterThanOrEqual => column_name!("maxValues"), Equal => { return 
as_data_skipping_predicate(&Expr::and( Expr::le(Column(col.clone()), Literal(val.clone())), @@ -127,14 +128,13 @@ fn as_data_skipping_predicate(expr: &Expr) -> Option { } NotEqual => { return Some(Expr::or( - Expr::gt(Column(format!("minValues.{}", col)), val.clone()), - Expr::lt(Column(format!("maxValues.{}", col)), val.clone()), + Expr::gt(joined_column_expr!("minValues", col), val.clone()), + Expr::lt(joined_column_expr!("maxValues", col), val.clone()), )); } _ => return None, // unsupported operation }; - let col = format!("{}.{}", stats_col, col); - Some(Expr::binary(op, Column(col), val.clone())) + Some(Expr::binary(op, stats_col.join(col), val.clone())) } // push down Not by inverting everything below it UnaryOperation { op: Not, expr } => as_inverted_data_skipping_predicate(expr), @@ -142,7 +142,7 @@ fn as_data_skipping_predicate(expr: &Expr) -> Option { // to check if a column could have a null, we need two different checks, to see if // the bounds are tight and then to actually do the check if let Column(col) = expr.as_ref() { - let null_col = format!("nullCount.{col}"); + let null_col = joined_column_expr!("nullCount", col); Some(Expr::or( get_tight_null_expr(null_col.clone()), get_wide_null_expr(null_col), @@ -185,9 +185,9 @@ impl DataSkippingFilter { static PREDICATE_SCHEMA: LazyLock = LazyLock::new(|| { DataType::struct_type([StructField::new("predicate", DataType::BOOLEAN, true)]) }); - static STATS_EXPR: LazyLock = LazyLock::new(|| Expr::column("add.stats")); + static STATS_EXPR: LazyLock = LazyLock::new(|| column_expr!("add.stats")); static FILTER_EXPR: LazyLock = - LazyLock::new(|| Expr::column("predicate").distinct(false)); + LazyLock::new(|| column_expr!("predicate").distinct(false)); let predicate = predicate.as_deref()?; debug!("Creating a data skipping filter for {}", &predicate); @@ -197,7 +197,7 @@ impl DataSkippingFilter { // extracting the corresponding field from the table schema, and inserting that field. 
let data_fields: Vec<_> = table_schema .fields() - .filter(|field| field_names.contains(&field.name.as_str())) + .filter(|field| field_names.contains(&field.name)) .cloned() .collect(); if data_fields.is_empty() { @@ -308,10 +308,10 @@ mod tests { #[test] fn test_rewrite_basic_comparison() { - let column = Expr::column("a"); + let column = column_expr!("a"); let lit_int = Expr::literal(1_i32); - let min_col = Expr::column("minValues.a"); - let max_col = Expr::column("maxValues.a"); + let min_col = column_expr!("minValues.a"); + let max_col = column_expr!("maxValues.a"); let cases = [ ( diff --git a/kernel/src/scan/log_replay.rs b/kernel/src/scan/log_replay.rs index 3c52e5e2c..f872a8eca 100644 --- a/kernel/src/scan/log_replay.rs +++ b/kernel/src/scan/log_replay.rs @@ -9,7 +9,7 @@ use super::ScanData; use crate::actions::{get_log_schema, ADD_NAME, REMOVE_NAME}; use crate::actions::{visitors::AddVisitor, visitors::RemoveVisitor, Add, Remove}; use crate::engine_data::{GetData, TypedGetData}; -use crate::expressions::{Expression, ExpressionRef}; +use crate::expressions::{column_expr, Expression, ExpressionRef}; use crate::schema::{DataType, MapType, SchemaRef, StructField, StructType}; use crate::{DataVisitor, DeltaResult, Engine, EngineData, ExpressionHandler}; @@ -125,12 +125,12 @@ impl LogReplayScanner { fn get_add_transform_expr(&self) -> Expression { Expression::Struct(vec![ - Expression::column("add.path"), - Expression::column("add.size"), - Expression::column("add.modificationTime"), - Expression::column("add.stats"), - Expression::column("add.deletionVector"), - Expression::Struct(vec![Expression::column("add.partitionValues")]), + column_expr!("add.path"), + column_expr!("add.size"), + column_expr!("add.modificationTime"), + column_expr!("add.stats"), + column_expr!("add.deletionVector"), + Expression::Struct(vec![column_expr!("add.partitionValues")]), ]) } diff --git a/kernel/src/scan/mod.rs b/kernel/src/scan/mod.rs index ca9887ead..b5c01e0a4 100644 --- 
a/kernel/src/scan/mod.rs +++ b/kernel/src/scan/mod.rs @@ -9,7 +9,7 @@ use url::Url; use crate::actions::deletion_vector::{split_vector, treemap_to_bools, DeletionVectorDescriptor}; use crate::actions::{get_log_add_schema, get_log_schema, ADD_NAME, REMOVE_NAME}; -use crate::expressions::{Expression, ExpressionRef, Scalar}; +use crate::expressions::{ColumnName, Expression, ExpressionRef, Scalar}; use crate::features::ColumnMappingMode; use crate::scan::state::{DvInfo, Stats}; use crate::schema::{DataType, Schema, SchemaRef, StructField, StructType}; @@ -160,7 +160,7 @@ impl ScanResult { /// to materialize the partition column. pub enum ColumnType { // A column, selected from the data, as is - Selected(String), + Selected(ColumnName), // A partition column that needs to be added back in Partition(usize), } @@ -421,7 +421,8 @@ fn get_state_info( debug!("\n\n{logical_field:#?}\nAfter mapping: {physical_field:#?}\n\n"); let physical_name = physical_field.name.clone(); read_fields.push(physical_field); - Ok(ColumnType::Selected(physical_name)) + // TODO: Support nested columns! + Ok(ColumnType::Selected(ColumnName::new([physical_name]))) } }) .try_collect()?; @@ -492,7 +493,7 @@ fn transform_to_logical_internal( )?; Ok::(value_expression.into()) } - ColumnType::Selected(field_name) => Ok(Expression::column(field_name)), + ColumnType::Selected(field_name) => Ok(field_name.clone().into()), }) .try_collect()?; let read_expression = Expression::Struct(all_fields); @@ -614,6 +615,7 @@ mod tests { use std::path::PathBuf; use crate::engine::sync::SyncEngine; + use crate::expressions::column_expr; use crate::schema::PrimitiveType; use crate::Table; @@ -757,7 +759,7 @@ mod tests { assert_eq!(data.len(), 1); // Ineffective predicate pushdown attempted, so the one data file should be returned. 
- let int_col = Expression::column("numeric.ints.int32"); + let int_col = column_expr!("numeric.ints.int32"); let value = Expression::literal(1000i32); let predicate = Arc::new(int_col.clone().gt(value.clone())); let scan = snapshot diff --git a/kernel/src/snapshot.rs b/kernel/src/snapshot.rs index 859ee8921..f119571e2 100644 --- a/kernel/src/snapshot.rs +++ b/kernel/src/snapshot.rs @@ -11,6 +11,7 @@ use tracing::{debug, warn}; use url::Url; use crate::actions::{get_log_schema, Metadata, Protocol, METADATA_NAME, PROTOCOL_NAME}; +use crate::expressions::column_expr; use crate::features::{ColumnMappingMode, COLUMN_MAPPING_MODE_KEY}; use crate::path::ParsedLogPath; use crate::scan::ScanBuilder; @@ -111,8 +112,8 @@ impl LogSegment { use Expression as Expr; static META_PREDICATE: LazyLock> = LazyLock::new(|| { Some(Arc::new(Expr::or( - Expr::column("metaData.id").is_not_null(), - Expr::column("protocol.minReaderVersion").is_not_null(), + column_expr!("metaData.id").is_not_null(), + column_expr!("protocol.minReaderVersion").is_not_null(), ))) }); // read the same protocol and metadata schema for both commits and checkpoints diff --git a/kernel/tests/read.rs b/kernel/tests/read.rs index f99c9147d..92bf70314 100644 --- a/kernel/tests/read.rs +++ b/kernel/tests/read.rs @@ -13,7 +13,7 @@ use delta_kernel::actions::deletion_vector::split_vector; use delta_kernel::engine::arrow_data::ArrowEngineData; use delta_kernel::engine::default::executor::tokio::TokioBackgroundExecutor; use delta_kernel::engine::default::DefaultEngine; -use delta_kernel::expressions::{BinaryOperator, Expression}; +use delta_kernel::expressions::{column_expr, BinaryOperator, Expression}; use delta_kernel::scan::state::{visit_scan_files, DvInfo, Stats}; use delta_kernel::scan::{transform_to_logical, Scan}; use delta_kernel::schema::Schema; @@ -339,7 +339,7 @@ async fn stats() -> Result<(), Box> { (NotEqual, 8, vec![&batch2, &batch1]), ]; for (op, value, expected_batches) in test_cases { - let predicate = 
Expression::binary(op, Expression::column("id"), value); + let predicate = Expression::binary(op, column_expr!("id"), value); let scan = snapshot .clone() .scan_builder() @@ -655,27 +655,24 @@ fn table_for_numbers(nums: Vec) -> Vec { fn predicate_on_number() -> Result<(), Box> { let cases = vec![ ( - Expression::column("number").lt(4i64), + column_expr!("number").lt(4i64), table_for_numbers(vec![1, 2, 3]), ), ( - Expression::column("number").le(4i64), + column_expr!("number").le(4i64), table_for_numbers(vec![1, 2, 3, 4]), ), ( - Expression::column("number").gt(4i64), + column_expr!("number").gt(4i64), table_for_numbers(vec![5, 6]), ), ( - Expression::column("number").ge(4i64), + column_expr!("number").ge(4i64), table_for_numbers(vec![4, 5, 6]), ), + (column_expr!("number").eq(4i64), table_for_numbers(vec![4])), ( - Expression::column("number").eq(4i64), - table_for_numbers(vec![4]), - ), - ( - Expression::column("number").ne(4i64), + column_expr!("number").ne(4i64), table_for_numbers(vec![1, 2, 3, 5, 6]), ), ]; @@ -695,27 +692,27 @@ fn predicate_on_number() -> Result<(), Box> { fn predicate_on_number_not() -> Result<(), Box> { let cases = vec![ ( - Expression::not(Expression::column("number").lt(4i64)), + Expression::not(column_expr!("number").lt(4i64)), table_for_numbers(vec![4, 5, 6]), ), ( - Expression::not(Expression::column("number").le(4i64)), + Expression::not(column_expr!("number").le(4i64)), table_for_numbers(vec![5, 6]), ), ( - Expression::not(Expression::column("number").gt(4i64)), + Expression::not(column_expr!("number").gt(4i64)), table_for_numbers(vec![1, 2, 3, 4]), ), ( - Expression::not(Expression::column("number").ge(4i64)), + Expression::not(column_expr!("number").ge(4i64)), table_for_numbers(vec![1, 2, 3]), ), ( - Expression::not(Expression::column("number").eq(4i64)), + Expression::not(column_expr!("number").eq(4i64)), table_for_numbers(vec![1, 2, 3, 5, 6]), ), ( - Expression::not(Expression::column("number").ne(4i64)), + 
Expression::not(column_expr!("number").ne(4i64)), table_for_numbers(vec![4]), ), ]; @@ -744,8 +741,8 @@ fn predicate_on_number_with_not_null() -> Result<(), Box> "./tests/data/basic_partitioned", Some(&["a_float", "number"]), Some(Expression::and( - Expression::column("number").is_not_null(), - Expression::column("number").lt(Expression::literal(3i64)), + column_expr!("number").is_not_null(), + column_expr!("number").lt(Expression::literal(3i64)), )), expected, )?; @@ -758,7 +755,7 @@ fn predicate_null() -> Result<(), Box> { read_table_data_str( "./tests/data/basic_partitioned", Some(&["a_float", "number"]), - Some(Expression::column("number").is_null()), + Some(column_expr!("number").is_null()), expected, )?; Ok(()) @@ -785,7 +782,7 @@ fn mixed_null() -> Result<(), Box> { read_table_data_str( "./tests/data/mixed-nulls", Some(&["part", "n"]), - Some(Expression::column("n").is_null()), + Some(column_expr!("n").is_null()), expected, )?; Ok(()) @@ -812,7 +809,7 @@ fn mixed_not_null() -> Result<(), Box> { read_table_data_str( "./tests/data/mixed-nulls", Some(&["part", "n"]), - Some(Expression::column("n").is_not_null()), + Some(column_expr!("n").is_not_null()), expected, )?; Ok(()) @@ -822,27 +819,27 @@ fn mixed_not_null() -> Result<(), Box> { fn and_or_predicates() -> Result<(), Box> { let cases = vec![ ( - Expression::column("number") + column_expr!("number") .gt(4i64) - .and(Expression::column("a_float").gt(5.5)), + .and(column_expr!("a_float").gt(5.5)), table_for_numbers(vec![6]), ), ( - Expression::column("number") + column_expr!("number") .gt(4i64) - .and(Expression::not(Expression::column("a_float").gt(5.5))), + .and(Expression::not(column_expr!("a_float").gt(5.5))), table_for_numbers(vec![5]), ), ( - Expression::column("number") + column_expr!("number") .gt(4i64) - .or(Expression::column("a_float").gt(5.5)), + .or(column_expr!("a_float").gt(5.5)), table_for_numbers(vec![5, 6]), ), ( - Expression::column("number") + column_expr!("number") .gt(4i64) - 
.or(Expression::not(Expression::column("a_float").gt(5.5))), + .or(Expression::not(column_expr!("a_float").gt(5.5))), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ]; @@ -862,33 +859,33 @@ fn not_and_or_predicates() -> Result<(), Box> { let cases = vec![ ( Expression::not( - Expression::column("number") + column_expr!("number") .gt(4i64) - .and(Expression::column("a_float").gt(5.5)), + .and(column_expr!("a_float").gt(5.5)), ), table_for_numbers(vec![1, 2, 3, 4, 5]), ), ( Expression::not( - Expression::column("number") + column_expr!("number") .gt(4i64) - .and(Expression::not(Expression::column("a_float").gt(5.5))), + .and(Expression::not(column_expr!("a_float").gt(5.5))), ), table_for_numbers(vec![1, 2, 3, 4, 6]), ), ( Expression::not( - Expression::column("number") + column_expr!("number") .gt(4i64) - .or(Expression::column("a_float").gt(5.5)), + .or(column_expr!("a_float").gt(5.5)), ), table_for_numbers(vec![1, 2, 3, 4]), ), ( Expression::not( - Expression::column("number") + column_expr!("number") .gt(4i64) - .or(Expression::not(Expression::column("a_float").gt(5.5))), + .or(Expression::not(column_expr!("a_float").gt(5.5))), ), vec![], ), @@ -913,23 +910,23 @@ fn invalid_skips_none_predicates() -> Result<(), Box> { table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::column("number").distinct(3i64), + column_expr!("number").distinct(3i64), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::column("number").gt(empty_struct.clone()), + column_expr!("number").gt(empty_struct.clone()), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::column("number").and(empty_struct.clone().is_null()), + column_expr!("number").and(empty_struct.clone().is_null()), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - Expression::not(Expression::column("number").gt(empty_struct.clone())), + Expression::not(column_expr!("number").gt(empty_struct.clone())), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ( - 
Expression::not(Expression::column("number").and(empty_struct.clone().is_null())), + Expression::not(column_expr!("number").and(empty_struct.clone().is_null())), table_for_numbers(vec![1, 2, 3, 4, 5, 6]), ), ]; @@ -963,7 +960,7 @@ fn with_predicate_and_removes() -> Result<(), Box> { read_table_data_str( "./tests/data/table-with-dv-small/", None, - Some(Expression::gt(Expression::column("value"), 3)), + Some(Expression::gt(column_expr!("value"), 3)), expected, )?; Ok(())