Skip to content

Commit

Permalink
Expression::Column references a ColumnName object instead of a String (
Browse files Browse the repository at this point in the history
…#400)

Kernel only partly supports nested column names today. One of the
biggest barriers to closing the gap is that Expression tracks column
names as simple strings, so there is no reliable way to deal with
nesting.

This PR takes a first step of defining a proper `ColumnName` struct for
hosting column names. So far, it behaves the same as before --
internally it's just a simple string, occasionally interpreted as nested
by splitting at periods. Changing the code to use this new construct is
noisy, so we make the changes as a pre-factor. A follow-up PR will
actually add nesting support.

Resolves #422
  • Loading branch information
scovich authored Oct 23, 2024
1 parent b09078d commit e48d238
Show file tree
Hide file tree
Showing 14 changed files with 441 additions and 144 deletions.
20 changes: 20 additions & 0 deletions derive-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,26 @@ use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, Meta, PathArguments, Type,
};

/// Parses a dot-delimited column name into an array of field names. See
/// [`delta_kernel::expressions::column_name::column_name`] macro for details.
#[proc_macro]
pub fn parse_column_name(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let is_valid = |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.';
let err = match syn::parse(input) {
Ok(syn::Lit::Str(name)) => match name.value().chars().find(|c| !is_valid(*c)) {
Some(bad_char) => Error::new(name.span(), format!("Invalid character: {bad_char:?}")),
_ => {
let path = name.value();
let path = path.split('.').map(proc_macro2::Literal::string);
return quote_spanned! { name.span() => [#(#path),*] }.into();
}
},
Ok(lit) => Error::new(lit.span(), "Expected a string literal"),
Err(err) => err,
};
err.into_compile_error().into()
}

/// Derive a `delta_kernel::schemas::ToDataType` implementation for the annotated struct. The actual
/// field names in the schema (and therefore of the struct members) are all mandated by the Delta
/// spec, and so the user of this macro is responsible for ensuring that
Expand Down
6 changes: 4 additions & 2 deletions ffi/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
ReferenceSet, TryFromStringSlice,
};
use delta_kernel::{
expressions::{BinaryOperator, Expression, UnaryOperator},
expressions::{BinaryOperator, ColumnName, Expression, UnaryOperator},
DeltaResult,
};

Expand Down Expand Up @@ -146,7 +146,9 @@ fn visit_expression_column_impl(
state: &mut KernelExpressionVisitorState,
name: DeltaResult<String>,
) -> DeltaResult<usize> {
Ok(wrap_expression(state, Expression::Column(name?)))
// TODO: FIXME: This is incorrect if any field name in the column path contains a period.
let name = ColumnName::new(name?.split('.')).into();
Ok(wrap_expression(state, name))
}

#[no_mangle]
Expand Down
5 changes: 3 additions & 2 deletions kernel/src/actions/set_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::sync::{Arc, LazyLock};

use crate::actions::visitors::SetTransactionVisitor;
use crate::actions::{get_log_schema, SetTransaction, SET_TRANSACTION_NAME};
use crate::expressions::column_expr;
use crate::snapshot::Snapshot;
use crate::{DeltaResult, Engine, EngineData, Expression, ExpressionRef, SchemaRef};
use crate::{DeltaResult, Engine, EngineData, ExpressionRef, SchemaRef};

pub use crate::actions::visitors::SetTransactionMap;
pub struct SetTransactionScanner {
Expand Down Expand Up @@ -53,7 +54,7 @@ impl SetTransactionScanner {
// point filtering by a particular app id, even if we have one, because app ids are all in
// the a single checkpoint part having large min/max range (because they're usually uuids).
static META_PREDICATE: LazyLock<Option<ExpressionRef>> =
LazyLock::new(|| Some(Arc::new(Expression::column("txn.appId").is_not_null())));
LazyLock::new(|| Some(Arc::new(column_expr!("txn.appId").is_not_null())));
self.snapshot
.log_segment
.replay(engine, schema.clone(), schema, META_PREDICATE.clone())
Expand Down
31 changes: 15 additions & 16 deletions kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,9 @@ mod tests {
let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();

let not_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item"));
let not_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item"));

let in_op = Expression::binary(BinaryOperator::In, 5, Expression::column("item"));
let in_op = Expression::binary(BinaryOperator::In, 5, column_expr!("item"));

let result = evaluate_expression(&not_op, &batch, None).unwrap();
let expected = BooleanArray::from(vec![true, false, true]);
Expand All @@ -609,7 +609,7 @@ mod tests {
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();

let in_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item"));
let in_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item"));

let in_result = evaluate_expression(&in_op, &batch, None);

Expand Down Expand Up @@ -654,8 +654,8 @@ mod tests {

let in_op = Expression::binary(
BinaryOperator::NotIn,
Expression::column("item"),
Expression::column("item"),
column_expr!("item"),
column_expr!("item"),
);

let in_result = evaluate_expression(&in_op, &batch, None);
Expand All @@ -679,10 +679,9 @@ mod tests {
let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();

let str_not_op =
Expression::binary(BinaryOperator::NotIn, "bye", Expression::column("item"));
let str_not_op = Expression::binary(BinaryOperator::NotIn, "bye", column_expr!("item"));

let str_in_op = Expression::binary(BinaryOperator::In, "hi", Expression::column("item"));
let str_in_op = Expression::binary(BinaryOperator::In, "hi", column_expr!("item"));

let result = evaluate_expression(&str_in_op, &batch, None).unwrap();
let expected = BooleanArray::from(vec![true, true, true]);
Expand All @@ -699,7 +698,7 @@ mod tests {
let values = Int32Array::from(vec![1, 2, 3]);
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values.clone())]).unwrap();
let column = Expression::column("a");
let column = column_expr!("a");

let results = evaluate_expression(&column, &batch, None).unwrap();
assert_eq!(results.as_ref(), &values);
Expand All @@ -720,7 +719,7 @@ mod tests {
vec![Arc::new(struct_array.clone())],
)
.unwrap();
let column = Expression::column("b.a");
let column = column_expr!("b.a");
let results = evaluate_expression(&column, &batch, None).unwrap();
assert_eq!(results.as_ref(), &values);
}
Expand All @@ -730,7 +729,7 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let values = Int32Array::from(vec![1, 2, 3]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();
let column = Expression::column("a");
let column = column_expr!("a");

let expression = column.clone().add(1);
let results = evaluate_expression(&expression, &batch, None).unwrap();
Expand Down Expand Up @@ -766,8 +765,8 @@ mod tests {
vec![Arc::new(values.clone()), Arc::new(values)],
)
.unwrap();
let column_a = Expression::column("a");
let column_b = Expression::column("b");
let column_a = column_expr!("a");
let column_b = column_expr!("b");

let expression = column_a.clone().add(column_b.clone());
let results = evaluate_expression(&expression, &batch, None).unwrap();
Expand All @@ -790,7 +789,7 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let values = Int32Array::from(vec![1, 2, 3]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();
let column = Expression::column("a");
let column = column_expr!("a");

let expression = column.clone().lt(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
Expand Down Expand Up @@ -837,8 +836,8 @@ mod tests {
],
)
.unwrap();
let column_a = Expression::column("a");
let column_b = Expression::column("b");
let column_a = column_expr!("a");
let column_b = column_expr!("b");

let expression = column_a.clone().and(column_b.clone());
let results =
Expand Down
31 changes: 16 additions & 15 deletions kernel/src/engine/parquet_row_group_skipping/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::expressions::column_expr;
use crate::Expression;
use parquet::arrow::arrow_reader::ArrowReaderMetadata;
use std::fs::File;
Expand Down Expand Up @@ -39,21 +40,21 @@ fn test_get_stat_values() {

// The expression doesn't matter -- it just needs to mention all the columns we care about.
let columns = Expression::and_from(vec![
Expression::column("varlen.utf8"),
Expression::column("numeric.ints.int64"),
Expression::column("numeric.ints.int32"),
Expression::column("numeric.ints.int16"),
Expression::column("numeric.ints.int8"),
Expression::column("numeric.floats.float32"),
Expression::column("numeric.floats.float64"),
Expression::column("bool"),
Expression::column("varlen.binary"),
Expression::column("numeric.decimals.decimal32"),
Expression::column("numeric.decimals.decimal64"),
Expression::column("numeric.decimals.decimal128"),
Expression::column("chrono.date32"),
Expression::column("chrono.timestamp"),
Expression::column("chrono.timestamp_ntz"),
column_expr!("varlen.utf8"),
column_expr!("numeric.ints.int64"),
column_expr!("numeric.ints.int32"),
column_expr!("numeric.ints.int16"),
column_expr!("numeric.ints.int8"),
column_expr!("numeric.floats.float32"),
column_expr!("numeric.floats.float64"),
column_expr!("bool"),
column_expr!("varlen.binary"),
column_expr!("numeric.decimals.decimal32"),
column_expr!("numeric.decimals.decimal64"),
column_expr!("numeric.decimals.decimal128"),
column_expr!("chrono.date32"),
column_expr!("chrono.timestamp"),
column_expr!("chrono.timestamp_ntz"),
]);
let filter = RowGroupFilter::new(metadata.metadata().row_group(0), &columns);

Expand Down
12 changes: 6 additions & 6 deletions kernel/src/engine/parquet_stats_skipping/tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::*;
use crate::expressions::{ArrayData, StructData};
use crate::expressions::{column_expr, ArrayData, StructData};
use crate::schema::ArrayType;
use crate::DataType;

Expand Down Expand Up @@ -337,7 +337,7 @@ fn test_binary_eq_ne() {
const LO: Scalar = Scalar::Long(1);
const MID: Scalar = Scalar::Long(10);
const HI: Scalar = Scalar::Long(100);
let col = &Expression::column("x");
let col = &column_expr!("x");

for inverted in [false, true] {
// negative test -- mismatched column type
Expand Down Expand Up @@ -485,7 +485,7 @@ fn test_binary_lt_ge() {
const LO: Scalar = Scalar::Long(1);
const MID: Scalar = Scalar::Long(10);
const HI: Scalar = Scalar::Long(100);
let col = &Expression::column("x");
let col = &column_expr!("x");

for inverted in [false, true] {
expect_eq!(
Expand Down Expand Up @@ -585,7 +585,7 @@ fn test_binary_le_gt() {
const LO: Scalar = Scalar::Long(1);
const MID: Scalar = Scalar::Long(10);
const HI: Scalar = Scalar::Long(100);
let col = &Expression::column("x");
let col = &column_expr!("x");

for inverted in [false, true] {
// negative test -- mismatched column type
Expand Down Expand Up @@ -736,7 +736,7 @@ impl ParquetStatsSkippingFilter for NullCountTestFilter {
fn test_not_null() {
use UnaryOperator::IsNull;

let col = &Expression::column("x");
let col = &column_expr!("x");
for inverted in [false, true] {
expect_eq!(
NullCountTestFilter::new(None, 10).apply_unary(IsNull, col, inverted),
Expand Down Expand Up @@ -809,7 +809,7 @@ impl ParquetStatsSkippingFilter for AllNullTestFilter {

#[test]
fn test_sql_where() {
let col = &Expression::column("x");
let col = &column_expr!("x");
let val = &Expression::literal(1);
const NULL: Expression = Expression::Literal(Scalar::Null(DataType::BOOLEAN));
const FALSE: Expression = Expression::Literal(Scalar::Boolean(false));
Expand Down
Loading

0 comments on commit e48d238

Please sign in to comment.