Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expression::Column references a ColumnName object instead of a String #400

Merged
merged 12 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions derive-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,26 @@ use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, Meta, PathArguments, Type,
};

/// Parses a dot-delimited column name into an array of field names. See
/// [`delta_kernel::expressions::column_name::column_name`] macro for details.
#[proc_macro]
pub fn parse_column_name(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
let is_valid = |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.';
let err = match syn::parse(input) {
Ok(syn::Lit::Str(name)) => match name.value().chars().find(|c| !is_valid(*c)) {
Some(bad_char) => Error::new(name.span(), format!("Invalid character: {bad_char:?}")),
_ => {
let path = name.value();
let path = path.split('.').map(proc_macro2::Literal::string);
return quote_spanned! { name.span() => [#(#path),*] }.into();
}
},
Ok(lit) => Error::new(lit.span(), "Expected a string literal"),
Err(err) => err,
};
err.into_compile_error().into()
}

/// Derive a `delta_kernel::schemas::ToDataType` implementation for the annotated struct. The actual
/// field names in the schema (and therefore of the struct members) are all mandated by the Delta
/// spec, and so the user of this macro is responsible for ensuring that
Expand Down
6 changes: 4 additions & 2 deletions ffi/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
ReferenceSet, TryFromStringSlice,
};
use delta_kernel::{
expressions::{BinaryOperator, Expression, UnaryOperator},
expressions::{BinaryOperator, ColumnName, Expression, UnaryOperator},
DeltaResult,
};

Expand Down Expand Up @@ -146,7 +146,9 @@ fn visit_expression_column_impl(
state: &mut KernelExpressionVisitorState,
name: DeltaResult<String>,
) -> DeltaResult<usize> {
Ok(wrap_expression(state, Expression::Column(name?)))
// TODO: FIXME: This is incorrect if any field name in the column path contains a period.
let name = ColumnName::new(name?.split('.')).into();
zachschuermann marked this conversation as resolved.
Show resolved Hide resolved
Ok(wrap_expression(state, name))
}

#[no_mangle]
Expand Down
5 changes: 3 additions & 2 deletions kernel/src/actions/set_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::sync::{Arc, LazyLock};

use crate::actions::visitors::SetTransactionVisitor;
use crate::actions::{get_log_schema, SetTransaction, SET_TRANSACTION_NAME};
use crate::expressions::column_expr;
use crate::snapshot::Snapshot;
use crate::{DeltaResult, Engine, EngineData, Expression, ExpressionRef, SchemaRef};
use crate::{DeltaResult, Engine, EngineData, ExpressionRef, SchemaRef};

pub use crate::actions::visitors::SetTransactionMap;
pub struct SetTransactionScanner {
Expand Down Expand Up @@ -53,7 +54,7 @@ impl SetTransactionScanner {
// point filtering by a particular app id, even if we have one, because app ids are all in
// the a single checkpoint part having large min/max range (because they're usually uuids).
static META_PREDICATE: LazyLock<Option<ExpressionRef>> =
LazyLock::new(|| Some(Arc::new(Expression::column("txn.appId").is_not_null())));
LazyLock::new(|| Some(Arc::new(column_expr!("txn.appId").is_not_null())));
self.snapshot
.log_segment
.replay(engine, schema.clone(), schema, META_PREDICATE.clone())
Expand Down
31 changes: 15 additions & 16 deletions kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -589,9 +589,9 @@ mod tests {
let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();

let not_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item"));
let not_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item"));

let in_op = Expression::binary(BinaryOperator::In, 5, Expression::column("item"));
let in_op = Expression::binary(BinaryOperator::In, 5, column_expr!("item"));

let result = evaluate_expression(&not_op, &batch, None).unwrap();
let expected = BooleanArray::from(vec![true, false, true]);
Expand All @@ -609,7 +609,7 @@ mod tests {
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();

let in_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item"));
let in_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item"));

let in_result = evaluate_expression(&in_op, &batch, None);

Expand Down Expand Up @@ -654,8 +654,8 @@ mod tests {

let in_op = Expression::binary(
BinaryOperator::NotIn,
Expression::column("item"),
Expression::column("item"),
column_expr!("item"),
column_expr!("item"),
);

let in_result = evaluate_expression(&in_op, &batch, None);
Expand All @@ -679,10 +679,9 @@ mod tests {
let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();

let str_not_op =
Expression::binary(BinaryOperator::NotIn, "bye", Expression::column("item"));
let str_not_op = Expression::binary(BinaryOperator::NotIn, "bye", column_expr!("item"));

let str_in_op = Expression::binary(BinaryOperator::In, "hi", Expression::column("item"));
let str_in_op = Expression::binary(BinaryOperator::In, "hi", column_expr!("item"));

let result = evaluate_expression(&str_in_op, &batch, None).unwrap();
let expected = BooleanArray::from(vec![true, true, true]);
Expand All @@ -699,7 +698,7 @@ mod tests {
let values = Int32Array::from(vec![1, 2, 3]);
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values.clone())]).unwrap();
let column = Expression::column("a");
let column = column_expr!("a");

let results = evaluate_expression(&column, &batch, None).unwrap();
assert_eq!(results.as_ref(), &values);
Expand All @@ -720,7 +719,7 @@ mod tests {
vec![Arc::new(struct_array.clone())],
)
.unwrap();
let column = Expression::column("b.a");
let column = column_expr!("b.a");
let results = evaluate_expression(&column, &batch, None).unwrap();
assert_eq!(results.as_ref(), &values);
}
Expand All @@ -730,7 +729,7 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let values = Int32Array::from(vec![1, 2, 3]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();
let column = Expression::column("a");
let column = column_expr!("a");

let expression = column.clone().add(1);
let results = evaluate_expression(&expression, &batch, None).unwrap();
Expand Down Expand Up @@ -766,8 +765,8 @@ mod tests {
vec![Arc::new(values.clone()), Arc::new(values)],
)
.unwrap();
let column_a = Expression::column("a");
let column_b = Expression::column("b");
let column_a = column_expr!("a");
let column_b = column_expr!("b");

let expression = column_a.clone().add(column_b.clone());
let results = evaluate_expression(&expression, &batch, None).unwrap();
Expand All @@ -790,7 +789,7 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let values = Int32Array::from(vec![1, 2, 3]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();
let column = Expression::column("a");
let column = column_expr!("a");

let expression = column.clone().lt(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
Expand Down Expand Up @@ -837,8 +836,8 @@ mod tests {
],
)
.unwrap();
let column_a = Expression::column("a");
let column_b = Expression::column("b");
let column_a = column_expr!("a");
let column_b = column_expr!("b");

let expression = column_a.clone().and(column_b.clone());
let results =
Expand Down
31 changes: 16 additions & 15 deletions kernel/src/engine/parquet_row_group_skipping/tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::expressions::column_expr;
use crate::Expression;
use parquet::arrow::arrow_reader::ArrowReaderMetadata;
use std::fs::File;
Expand Down Expand Up @@ -39,21 +40,21 @@ fn test_get_stat_values() {

// The expression doesn't matter -- it just needs to mention all the columns we care about.
let columns = Expression::and_from(vec![
Expression::column("varlen.utf8"),
Expression::column("numeric.ints.int64"),
Expression::column("numeric.ints.int32"),
Expression::column("numeric.ints.int16"),
Expression::column("numeric.ints.int8"),
Expression::column("numeric.floats.float32"),
Expression::column("numeric.floats.float64"),
Expression::column("bool"),
Expression::column("varlen.binary"),
Expression::column("numeric.decimals.decimal32"),
Expression::column("numeric.decimals.decimal64"),
Expression::column("numeric.decimals.decimal128"),
Expression::column("chrono.date32"),
Expression::column("chrono.timestamp"),
Expression::column("chrono.timestamp_ntz"),
column_expr!("varlen.utf8"),
column_expr!("numeric.ints.int64"),
column_expr!("numeric.ints.int32"),
column_expr!("numeric.ints.int16"),
column_expr!("numeric.ints.int8"),
column_expr!("numeric.floats.float32"),
column_expr!("numeric.floats.float64"),
column_expr!("bool"),
column_expr!("varlen.binary"),
column_expr!("numeric.decimals.decimal32"),
column_expr!("numeric.decimals.decimal64"),
column_expr!("numeric.decimals.decimal128"),
column_expr!("chrono.date32"),
column_expr!("chrono.timestamp"),
column_expr!("chrono.timestamp_ntz"),
]);
let filter = RowGroupFilter::new(metadata.metadata().row_group(0), &columns);

Expand Down
12 changes: 6 additions & 6 deletions kernel/src/engine/parquet_stats_skipping/tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::*;
use crate::expressions::{ArrayData, StructData};
use crate::expressions::{column_expr, ArrayData, StructData};
use crate::schema::ArrayType;
use crate::DataType;

Expand Down Expand Up @@ -337,7 +337,7 @@ fn test_binary_eq_ne() {
const LO: Scalar = Scalar::Long(1);
const MID: Scalar = Scalar::Long(10);
const HI: Scalar = Scalar::Long(100);
let col = &Expression::column("x");
let col = &column_expr!("x");

for inverted in [false, true] {
// negative test -- mismatched column type
Expand Down Expand Up @@ -485,7 +485,7 @@ fn test_binary_lt_ge() {
const LO: Scalar = Scalar::Long(1);
const MID: Scalar = Scalar::Long(10);
const HI: Scalar = Scalar::Long(100);
let col = &Expression::column("x");
let col = &column_expr!("x");

for inverted in [false, true] {
expect_eq!(
Expand Down Expand Up @@ -585,7 +585,7 @@ fn test_binary_le_gt() {
const LO: Scalar = Scalar::Long(1);
const MID: Scalar = Scalar::Long(10);
const HI: Scalar = Scalar::Long(100);
let col = &Expression::column("x");
let col = &column_expr!("x");

for inverted in [false, true] {
// negative test -- mismatched column type
Expand Down Expand Up @@ -736,7 +736,7 @@ impl ParquetStatsSkippingFilter for NullCountTestFilter {
fn test_not_null() {
use UnaryOperator::IsNull;

let col = &Expression::column("x");
let col = &column_expr!("x");
for inverted in [false, true] {
expect_eq!(
NullCountTestFilter::new(None, 10).apply_unary(IsNull, col, inverted),
Expand Down Expand Up @@ -809,7 +809,7 @@ impl ParquetStatsSkippingFilter for AllNullTestFilter {

#[test]
fn test_sql_where() {
let col = &Expression::column("x");
let col = &column_expr!("x");
let val = &Expression::literal(1);
const NULL: Expression = Expression::Literal(Scalar::Null(DataType::BOOLEAN));
const FALSE: Expression = Expression::Literal(Scalar::Boolean(false));
Expand Down
Loading
Loading