Commit

Merge remote-tracking branch 'upstream/main' into transaction
zachschuermann committed Oct 23, 2024
2 parents f5530f9 + e48d238 commit 37db615
Showing 14 changed files with 441 additions and 144 deletions.
20 changes: 20 additions & 0 deletions derive-macros/src/lib.rs
@@ -5,6 +5,26 @@ use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Error, Fields, Meta, PathArguments, Type,
};

+/// Parses a dot-delimited column name into an array of field names. See
+/// [`delta_kernel::expressions::column_name::column_name`] macro for details.
+#[proc_macro]
+pub fn parse_column_name(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
+    let is_valid = |c: char| c.is_ascii_alphanumeric() || c == '_' || c == '.';
+    let err = match syn::parse(input) {
+        Ok(syn::Lit::Str(name)) => match name.value().chars().find(|c| !is_valid(*c)) {
+            Some(bad_char) => Error::new(name.span(), format!("Invalid character: {bad_char:?}")),
+            _ => {
+                let path = name.value();
+                let path = path.split('.').map(proc_macro2::Literal::string);
+                return quote_spanned! { name.span() => [#(#path),*] }.into();
+            }
+        },
+        Ok(lit) => Error::new(lit.span(), "Expected a string literal"),
+        Err(err) => err,
+    };
+    err.into_compile_error().into()
+}

/// Derive a `delta_kernel::schemas::ToDataType` implementation for the annotated struct. The actual
/// field names in the schema (and therefore of the struct members) are all mandated by the Delta
/// spec, and so the user of this macro is responsible for ensuring that
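As a usage sketch of the parse_column_name macro added above (the caller-side crate name `delta_kernel_derive` is an assumption; the macro itself is what this file defines):

    use delta_kernel_derive::parse_column_name; // assumed export path

    fn main() {
        // Expands to the array of string literals ["a", "b", "c"], matching
        // the quote_spanned! output in the macro body above.
        let fields = parse_column_name!("a.b.c");
        assert_eq!(fields, ["a", "b", "c"]);
        // parse_column_name!("a b"); // rejected at compile time: Invalid character: ' '
    }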
6 changes: 4 additions & 2 deletions ffi/src/expressions.rs
@@ -5,7 +5,7 @@ use crate::{
ReferenceSet, TryFromStringSlice,
};
use delta_kernel::{
-    expressions::{BinaryOperator, Expression, UnaryOperator},
+    expressions::{BinaryOperator, ColumnName, Expression, UnaryOperator},
DeltaResult,
};

@@ -146,7 +146,9 @@ fn visit_expression_column_impl(
state: &mut KernelExpressionVisitorState,
name: DeltaResult<String>,
) -> DeltaResult<usize> {
-    Ok(wrap_expression(state, Expression::Column(name?)))
+    // TODO: FIXME: This is incorrect if any field name in the column path contains a period.
+    let name = ColumnName::new(name?.split('.')).into();
+    Ok(wrap_expression(state, name))
}

#[no_mangle]
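The FIXME above is easy to reproduce in a sketch (hypothetical snippet; assumes `ColumnName::new` accepts any iterator of field-name strings, as the call above suggests, and that `ColumnName` supports equality):

    use delta_kernel::expressions::ColumnName;

    fn main() {
        // Splitting on '.' treats every period as a nesting boundary, so a
        // top-level field literally named "a.b" becomes indistinguishable
        // from the nested path a -> b once it crosses the FFI string boundary.
        let nested = ColumnName::new("a.b".split('.')); // two fields: ["a", "b"]
        let literal = ColumnName::new(["a.b"]);         // one field: ["a.b"]
        assert_ne!(nested, literal);
    }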
5 changes: 3 additions & 2 deletions kernel/src/actions/set_transaction.rs
@@ -2,8 +2,9 @@ use std::sync::{Arc, LazyLock};

use crate::actions::visitors::SetTransactionVisitor;
use crate::actions::{get_log_schema, SetTransaction, SET_TRANSACTION_NAME};
+use crate::expressions::column_expr;
use crate::snapshot::Snapshot;
-use crate::{DeltaResult, Engine, EngineData, Expression, ExpressionRef, SchemaRef};
+use crate::{DeltaResult, Engine, EngineData, ExpressionRef, SchemaRef};

pub use crate::actions::visitors::SetTransactionMap;
pub struct SetTransactionScanner {
@@ -53,7 +54,7 @@ impl SetTransactionScanner {
        // point filtering by a particular app id, even if we have one, because app ids are all in
        // a single checkpoint part having a large min/max range (because they're usually uuids).
static META_PREDICATE: LazyLock<Option<ExpressionRef>> =
-            LazyLock::new(|| Some(Arc::new(Expression::column("txn.appId").is_not_null())));
+            LazyLock::new(|| Some(Arc::new(column_expr!("txn.appId").is_not_null())));
self.snapshot
.log_segment
.replay(engine, schema.clone(), schema, META_PREDICATE.clone())
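For reference, a minimal external-crate sketch of the new pattern (import paths are assumptions; `column_expr!` builds the column reference at compile time, where the old `Expression::column("txn.appId")` parsed the dotted string at runtime):

    use std::sync::{Arc, LazyLock};
    use delta_kernel::expressions::column_expr; // assumed public path
    use delta_kernel::ExpressionRef;

    // Same shape as META_PREDICATE above: a lazily built, shareable
    // "txn.appId IS NOT NULL" predicate for log replay.
    static META_PREDICATE: LazyLock<Option<ExpressionRef>> =
        LazyLock::new(|| Some(Arc::new(column_expr!("txn.appId").is_not_null())));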
31 changes: 15 additions & 16 deletions kernel/src/engine/arrow_expression.rs
@@ -589,9 +589,9 @@ mod tests {
let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();

-        let not_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item"));
+        let not_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item"));

-        let in_op = Expression::binary(BinaryOperator::In, 5, Expression::column("item"));
+        let in_op = Expression::binary(BinaryOperator::In, 5, column_expr!("item"));

let result = evaluate_expression(&not_op, &batch, None).unwrap();
let expected = BooleanArray::from(vec![true, false, true]);
@@ -609,7 +609,7 @@ mod tests {
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();

-        let in_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item"));
+        let in_op = Expression::binary(BinaryOperator::NotIn, 5, column_expr!("item"));

let in_result = evaluate_expression(&in_op, &batch, None);

@@ -654,8 +654,8 @@ mod tests {

let in_op = Expression::binary(
BinaryOperator::NotIn,
Expression::column("item"),
Expression::column("item"),
column_expr!("item"),
column_expr!("item"),
);

let in_result = evaluate_expression(&in_op, &batch, None);
@@ -679,10 +679,9 @@ mod tests {
let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();

-        let str_not_op =
-            Expression::binary(BinaryOperator::NotIn, "bye", Expression::column("item"));
+        let str_not_op = Expression::binary(BinaryOperator::NotIn, "bye", column_expr!("item"));

-        let str_in_op = Expression::binary(BinaryOperator::In, "hi", Expression::column("item"));
+        let str_in_op = Expression::binary(BinaryOperator::In, "hi", column_expr!("item"));

let result = evaluate_expression(&str_in_op, &batch, None).unwrap();
let expected = BooleanArray::from(vec![true, true, true]);
@@ -699,7 +698,7 @@ mod tests {
let values = Int32Array::from(vec![1, 2, 3]);
let batch =
RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values.clone())]).unwrap();
let column = Expression::column("a");
let column = column_expr!("a");

let results = evaluate_expression(&column, &batch, None).unwrap();
assert_eq!(results.as_ref(), &values);
@@ -720,7 +719,7 @@ mod tests {
vec![Arc::new(struct_array.clone())],
)
.unwrap();
let column = Expression::column("b.a");
let column = column_expr!("b.a");
let results = evaluate_expression(&column, &batch, None).unwrap();
assert_eq!(results.as_ref(), &values);
}
@@ -730,7 +729,7 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let values = Int32Array::from(vec![1, 2, 3]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();
let column = Expression::column("a");
let column = column_expr!("a");

let expression = column.clone().add(1);
let results = evaluate_expression(&expression, &batch, None).unwrap();
@@ -766,8 +765,8 @@ mod tests {
vec![Arc::new(values.clone()), Arc::new(values)],
)
.unwrap();
let column_a = Expression::column("a");
let column_b = Expression::column("b");
let column_a = column_expr!("a");
let column_b = column_expr!("b");

let expression = column_a.clone().add(column_b.clone());
let results = evaluate_expression(&expression, &batch, None).unwrap();
@@ -790,7 +789,7 @@ mod tests {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
let values = Int32Array::from(vec![1, 2, 3]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();
let column = Expression::column("a");
let column = column_expr!("a");

let expression = column.clone().lt(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
@@ -837,8 +836,8 @@ mod tests {
],
)
.unwrap();
let column_a = Expression::column("a");
let column_b = Expression::column("b");
let column_a = column_expr!("a");
let column_b = column_expr!("b");

let expression = column_a.clone().and(column_b.clone());
let results =
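The migrated tests above all share one shape; here is a condensed sketch of it (a hypothetical test, assuming this module's apparent arrow imports, arrow_array and arrow_schema):

    use std::sync::Arc;
    use arrow_array::{BooleanArray, Int32Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};

    #[test]
    fn test_column_comparison_sketch() {
        // Build a single-column Int32 batch, then evaluate a comparison whose
        // column reference is produced by column_expr! instead of the old
        // Expression::column string parse.
        let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
        let values = Int32Array::from(vec![1, 2, 3]);
        let batch =
            RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values)]).unwrap();

        let expression = column_expr!("a").lt(2);
        let results = evaluate_expression(&expression, &batch, None).unwrap();
        let expected = BooleanArray::from(vec![true, false, false]);
        assert_eq!(results.as_ref(), &expected);
    }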
31 changes: 16 additions & 15 deletions kernel/src/engine/parquet_row_group_skipping/tests.rs
@@ -1,4 +1,5 @@
use super::*;
+use crate::expressions::column_expr;
use crate::Expression;
use parquet::arrow::arrow_reader::ArrowReaderMetadata;
use std::fs::File;
@@ -39,21 +40,21 @@ fn test_get_stat_values() {

// The expression doesn't matter -- it just needs to mention all the columns we care about.
let columns = Expression::and_from(vec![
Expression::column("varlen.utf8"),
Expression::column("numeric.ints.int64"),
Expression::column("numeric.ints.int32"),
Expression::column("numeric.ints.int16"),
Expression::column("numeric.ints.int8"),
Expression::column("numeric.floats.float32"),
Expression::column("numeric.floats.float64"),
Expression::column("bool"),
Expression::column("varlen.binary"),
Expression::column("numeric.decimals.decimal32"),
Expression::column("numeric.decimals.decimal64"),
Expression::column("numeric.decimals.decimal128"),
Expression::column("chrono.date32"),
Expression::column("chrono.timestamp"),
Expression::column("chrono.timestamp_ntz"),
column_expr!("varlen.utf8"),
column_expr!("numeric.ints.int64"),
column_expr!("numeric.ints.int32"),
column_expr!("numeric.ints.int16"),
column_expr!("numeric.ints.int8"),
column_expr!("numeric.floats.float32"),
column_expr!("numeric.floats.float64"),
column_expr!("bool"),
column_expr!("varlen.binary"),
column_expr!("numeric.decimals.decimal32"),
column_expr!("numeric.decimals.decimal64"),
column_expr!("numeric.decimals.decimal128"),
column_expr!("chrono.date32"),
column_expr!("chrono.timestamp"),
column_expr!("chrono.timestamp_ntz"),
]);
let filter = RowGroupFilter::new(metadata.metadata().row_group(0), &columns);

12 changes: 6 additions & 6 deletions kernel/src/engine/parquet_stats_skipping/tests.rs
@@ -1,5 +1,5 @@
use super::*;
-use crate::expressions::{ArrayData, StructData};
+use crate::expressions::{column_expr, ArrayData, StructData};
use crate::schema::ArrayType;
use crate::DataType;

@@ -337,7 +337,7 @@ fn test_binary_eq_ne() {
const LO: Scalar = Scalar::Long(1);
const MID: Scalar = Scalar::Long(10);
const HI: Scalar = Scalar::Long(100);
let col = &Expression::column("x");
let col = &column_expr!("x");

for inverted in [false, true] {
// negative test -- mismatched column type
@@ -485,7 +485,7 @@ fn test_binary_lt_ge() {
const LO: Scalar = Scalar::Long(1);
const MID: Scalar = Scalar::Long(10);
const HI: Scalar = Scalar::Long(100);
let col = &Expression::column("x");
let col = &column_expr!("x");

for inverted in [false, true] {
expect_eq!(
@@ -585,7 +585,7 @@ fn test_binary_le_gt() {
const LO: Scalar = Scalar::Long(1);
const MID: Scalar = Scalar::Long(10);
const HI: Scalar = Scalar::Long(100);
let col = &Expression::column("x");
let col = &column_expr!("x");

for inverted in [false, true] {
// negative test -- mismatched column type
@@ -736,7 +736,7 @@ impl ParquetStatsSkippingFilter for NullCountTestFilter {
fn test_not_null() {
use UnaryOperator::IsNull;

let col = &Expression::column("x");
let col = &column_expr!("x");
for inverted in [false, true] {
expect_eq!(
NullCountTestFilter::new(None, 10).apply_unary(IsNull, col, inverted),
@@ -809,7 +809,7 @@ impl ParquetStatsSkippingFilter for AllNullTestFilter {

#[test]
fn test_sql_where() {
let col = &Expression::column("x");
let col = &column_expr!("x");
let val = &Expression::literal(1);
const NULL: Expression = Expression::Literal(Scalar::Null(DataType::BOOLEAN));
const FALSE: Expression = Expression::Literal(Scalar::Boolean(false));
(Diffs for the remaining 8 of the 14 changed files are not shown.)