Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Chore] Use helpers to simplify expression handling code #399

Merged
merged 7 commits into from
Oct 18, 2024
15 changes: 5 additions & 10 deletions ffi/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::{
ReferenceSet, TryFromStringSlice,
};
use delta_kernel::{
expressions::{BinaryOperator, Expression, Scalar, UnaryOperator},
expressions::{BinaryOperator, Expression, UnaryOperator},
DeltaResult,
};

Expand Down Expand Up @@ -56,12 +56,10 @@ fn visit_expression_binary(
a: usize,
b: usize,
) -> usize {
let left = unwrap_kernel_expression(state, a).map(Box::new);
let right = unwrap_kernel_expression(state, b).map(Box::new);
let left = unwrap_kernel_expression(state, a);
let right = unwrap_kernel_expression(state, b);
match left.zip(right) {
Some((left, right)) => {
wrap_expression(state, Expression::BinaryOperation { op, left, right })
}
Some((left, right)) => wrap_expression(state, Expression::binary(op, left, right)),
None => 0, // invalid child => invalid node
}
}
Expand Down Expand Up @@ -182,10 +180,7 @@ fn visit_expression_literal_string_impl(
state: &mut KernelExpressionVisitorState,
value: DeltaResult<String>,
) -> DeltaResult<usize> {
Ok(wrap_expression(
state,
Expression::Literal(Scalar::from(value?)),
))
Ok(wrap_expression(state, Expression::literal(value?)))
}

// We need to get parse.expand working to be able to macro everything below, see issue #255
Expand Down
2 changes: 1 addition & 1 deletion kernel/src/actions/deletion_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl DeletionVectorDescriptor {
let path_len = self.path_or_inline_dv.len();
require!(
path_len >= 20,
Error::deletion_vector("Invalid length {path_len}, must be >= 20",)
Error::deletion_vector("Invalid length {path_len}, must be >= 20")
);
let prefix_len = path_len - 20;
let decoded = z85::decode(&self.path_or_inline_dv[prefix_len..])
Expand Down
2 changes: 1 addition & 1 deletion kernel/src/actions/visitors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ mod tests {
app_id: "myApp2".to_string(),
version: 4,
last_updated: Some(1670892998177),
},)
})
);
assert_eq!(
actual.remove("myApp"),
Expand Down
26 changes: 10 additions & 16 deletions kernel/src/engine/arrow_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,27 +213,21 @@ impl TryFrom<&ArrowDataType> for DataType {
ArrowDataType::Struct(fields) => {
DataType::try_struct_type(fields.iter().map(|field| field.as_ref().try_into()))
}
ArrowDataType::List(field) => Ok(DataType::Array(Box::new(ArrayType::new(
(*field).data_type().try_into()?,
(*field).is_nullable(),
)))),
ArrowDataType::LargeList(field) => Ok(DataType::Array(Box::new(ArrayType::new(
(*field).data_type().try_into()?,
(*field).is_nullable(),
)))),
ArrowDataType::FixedSizeList(field, _) => Ok(DataType::Array(Box::new(
ArrayType::new((*field).data_type().try_into()?, (*field).is_nullable()),
))),
ArrowDataType::List(field) => {
Ok(ArrayType::new((*field).data_type().try_into()?, (*field).is_nullable()).into())
}
ArrowDataType::LargeList(field) => {
Ok(ArrayType::new((*field).data_type().try_into()?, (*field).is_nullable()).into())
}
ArrowDataType::FixedSizeList(field, _) => {
Ok(ArrayType::new((*field).data_type().try_into()?, (*field).is_nullable()).into())
}
ArrowDataType::Map(field, _) => {
if let ArrowDataType::Struct(struct_fields) = field.data_type() {
let key_type = DataType::try_from(struct_fields[0].data_type())?;
let value_type = DataType::try_from(struct_fields[1].data_type())?;
let value_type_nullable = struct_fields[1].is_nullable();
Ok(DataType::Map(Box::new(MapType::new(
key_type,
value_type,
value_type_nullable,
))))
Ok(MapType::new(key_type, value_type, value_type_nullable).into())
} else {
panic!("DataType::Map should contain a struct field child");
}
Expand Down
76 changes: 26 additions & 50 deletions kernel/src/engine/arrow_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,17 +431,9 @@ mod tests {
let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();

let not_op = Expression::binary(
BinaryOperator::NotIn,
Expression::literal(5),
Expression::column("item"),
);
let not_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item"));

let in_op = Expression::binary(
BinaryOperator::In,
Expression::literal(5),
Expression::column("item"),
);
let in_op = Expression::binary(BinaryOperator::In, 5, Expression::column("item"));

let result = evaluate_expression(&not_op, &batch, None).unwrap();
let expected = BooleanArray::from(vec![true, false, true]);
Expand All @@ -459,11 +451,7 @@ mod tests {
let schema = Schema::new([field.clone()]);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(values.clone())]).unwrap();

let in_op = Expression::binary(
BinaryOperator::NotIn,
Expression::literal(5),
Expression::column("item"),
);
let in_op = Expression::binary(BinaryOperator::NotIn, 5, Expression::column("item"));

let in_result = evaluate_expression(&in_op, &batch, None);

Expand All @@ -482,11 +470,11 @@ mod tests {

let in_op = Expression::binary(
BinaryOperator::NotIn,
Expression::literal(5),
Expression::literal(Scalar::Array(ArrayData::new(
5,
Scalar::Array(ArrayData::new(
ArrayType::new(DeltaDataTypes::INTEGER, false),
vec![Scalar::Integer(1), Scalar::Integer(2)],
))),
)),
);

let in_result = evaluate_expression(&in_op, &batch, None).unwrap();
Expand Down Expand Up @@ -533,17 +521,10 @@ mod tests {
let array = ListArray::new(field.clone(), offsets, Arc::new(values), None);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())]).unwrap();

let str_not_op = Expression::binary(
BinaryOperator::NotIn,
Expression::literal("bye"),
Expression::column("item"),
);
let str_not_op =
Expression::binary(BinaryOperator::NotIn, "bye", Expression::column("item"));

let str_in_op = Expression::binary(
BinaryOperator::In,
Expression::literal("hi"),
Expression::column("item"),
);
let str_in_op = Expression::binary(BinaryOperator::In, "hi", Expression::column("item"));

let result = evaluate_expression(&str_in_op, &batch, None).unwrap();
let expected = BooleanArray::from(vec![true, true, true]);
Expand Down Expand Up @@ -593,23 +574,23 @@ mod tests {
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();
let column = Expression::column("a");

let expression = Box::new(column.clone().add(Expression::Literal(Scalar::Integer(1))));
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why these tests were boxing everything... but it's not needed now.

let expression = column.clone().add(1);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(Int32Array::from(vec![2, 3, 4]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column.clone().sub(Expression::Literal(Scalar::Integer(1))));
let expression = column.clone().sub(1);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(Int32Array::from(vec![0, 1, 2]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column.clone().mul(Expression::Literal(Scalar::Integer(2))));
let expression = column.clone().mul(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(Int32Array::from(vec![2, 4, 6]));
assert_eq!(results.as_ref(), expected.as_ref());

// TODO handle type casting
let expression = Box::new(column.div(Expression::Literal(Scalar::Integer(1))));
let expression = column.div(1);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(Int32Array::from(vec![1, 2, 3]));
assert_eq!(results.as_ref(), expected.as_ref())
Expand All @@ -630,17 +611,17 @@ mod tests {
let column_a = Expression::column("a");
let column_b = Expression::column("b");

let expression = Box::new(column_a.clone().add(column_b.clone()));
let expression = column_a.clone().add(column_b.clone());
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(Int32Array::from(vec![2, 4, 6]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column_a.clone().sub(column_b.clone()));
let expression = column_a.clone().sub(column_b.clone());
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(Int32Array::from(vec![0, 0, 0]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column_a.clone().mul(column_b));
let expression = column_a.clone().mul(column_b);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(Int32Array::from(vec![1, 4, 9]));
assert_eq!(results.as_ref(), expected.as_ref());
Expand All @@ -652,34 +633,33 @@ mod tests {
let values = Int32Array::from(vec![1, 2, 3]);
let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(values)]).unwrap();
let column = Expression::column("a");
let lit = Expression::Literal(Scalar::Integer(2));

let expression = Box::new(column.clone().lt(lit.clone()));
let expression = column.clone().lt(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(BooleanArray::from(vec![true, false, false]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column.clone().lt_eq(lit.clone()));
let expression = column.clone().lt_eq(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(BooleanArray::from(vec![true, true, false]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column.clone().gt(lit.clone()));
let expression = column.clone().gt(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(BooleanArray::from(vec![false, false, true]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column.clone().gt_eq(lit.clone()));
let expression = column.clone().gt_eq(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(BooleanArray::from(vec![false, true, true]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column.clone().eq(lit.clone()));
let expression = column.clone().eq(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(BooleanArray::from(vec![false, true, false]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column.clone().ne(lit.clone()));
let expression = column.clone().ne(2);
let results = evaluate_expression(&expression, &batch, None).unwrap();
let expected = Arc::new(BooleanArray::from(vec![true, false, true]));
assert_eq!(results.as_ref(), expected.as_ref());
Expand All @@ -702,32 +682,28 @@ mod tests {
let column_a = Expression::column("a");
let column_b = Expression::column("b");

let expression = Box::new(column_a.clone().and(column_b.clone()));
let expression = column_a.clone().and(column_b.clone());
let results =
evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN))
.unwrap();
let expected = Arc::new(BooleanArray::from(vec![false, false]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column_a.clone().and(Expression::literal(true)));
let expression = column_a.clone().and(true);
let results =
evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN))
.unwrap();
let expected = Arc::new(BooleanArray::from(vec![true, false]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(column_a.clone().or(column_b));
let expression = column_a.clone().or(column_b);
let results =
evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN))
.unwrap();
let expected = Arc::new(BooleanArray::from(vec![true, true]));
assert_eq!(results.as_ref(), expected.as_ref());

let expression = Box::new(
column_a
.clone()
.or(Expression::literal(Scalar::Boolean(false))),
);
let expression = column_a.clone().or(false);
let results =
evaluate_expression(&expression, &batch, Some(&crate::schema::DataType::BOOLEAN))
.unwrap();
Expand Down
6 changes: 1 addition & 5 deletions kernel/src/engine/arrow_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -952,11 +952,7 @@ mod tests {
fn mask_with_map() {
let requested_schema = Arc::new(StructType::new([StructField::new(
"map",
DataType::Map(Box::new(MapType::new(
DataType::INTEGER,
DataType::STRING,
false,
))),
MapType::new(DataType::INTEGER, DataType::STRING, false),
false,
)]));
let parquet_schema = Arc::new(ArrowSchema::new(vec![ArrowField::new_map(
Expand Down
36 changes: 18 additions & 18 deletions kernel/src/engine/parquet_stats_skipping/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,76 +854,76 @@ fn test_sql_where() {
// Contrast normal vs. SQL WHERE semantics - comparison inside AND
expect_eq!(
AllNullTestFilter.apply_expr(
&Expression::and_from([NULL, Expression::lt(col.clone(), val.clone()),]),
&Expression::and(NULL, Expression::lt(col.clone(), val.clone())),
false
),
None,
"{NULL} AND {col} < {val}"
);
expect_eq!(
AllNullTestFilter.apply_sql_where(&Expression::and_from([
AllNullTestFilter.apply_sql_where(&Expression::and(
NULL,
Expression::lt(col.clone(), val.clone()),
])),
)),
Some(false),
"WHERE {NULL} AND {col} < {val}"
);

expect_eq!(
AllNullTestFilter.apply_expr(
&Expression::and_from([TRUE, Expression::lt(col.clone(), val.clone()),]),
&Expression::and(TRUE, Expression::lt(col.clone(), val.clone())),
false
),
None, // NULL (from the NULL check) is stronger than TRUE
"{TRUE} AND {col} < {val}"
);
expect_eq!(
AllNullTestFilter.apply_sql_where(&Expression::and_from([
AllNullTestFilter.apply_sql_where(&Expression::and(
TRUE,
Expression::lt(col.clone(), val.clone()),
])),
)),
Some(false), // FALSE (from the NULL check) is stronger than TRUE
"WHERE {TRUE} AND {col} < {val}"
);

// Contrast normal vs. SQL WHERE semantics - comparison inside AND inside AND
expect_eq!(
AllNullTestFilter.apply_expr(
&Expression::and_from([
&Expression::and(
TRUE,
Expression::and_from([NULL, Expression::lt(col.clone(), val.clone()),]),
]),
Expression::and(NULL, Expression::lt(col.clone(), val.clone())),
),
false,
),
None,
"{TRUE} AND ({NULL} AND {col} < {val})"
);
expect_eq!(
AllNullTestFilter.apply_sql_where(&Expression::and_from([
AllNullTestFilter.apply_sql_where(&Expression::and(
TRUE,
Expression::and_from([NULL, Expression::lt(col.clone(), val.clone()),]),
])),
Expression::and(NULL, Expression::lt(col.clone(), val.clone())),
)),
Some(false),
"WHERE {TRUE} AND ({NULL} AND {col} < {val})"
);

// Semantics are the same for comparison inside AND inside OR
expect_eq!(
AllNullTestFilter.apply_expr(
&Expression::or_from([
&Expression::or(
FALSE,
Expression::and_from([NULL, Expression::lt(col.clone(), val.clone()),]),
]),
Expression::and(NULL, Expression::lt(col.clone(), val.clone())),
),
false,
),
None,
"{FALSE} OR ({NULL} AND {col} < {val})"
);
expect_eq!(
AllNullTestFilter.apply_sql_where(&Expression::or_from([
AllNullTestFilter.apply_sql_where(&Expression::or(
FALSE,
Expression::and_from([NULL, Expression::lt(col.clone(), val.clone()),]),
])),
Expression::and(NULL, Expression::lt(col.clone(), val.clone())),
)),
None,
"WHERE {FALSE} OR ({NULL} AND {col} < {val})"
);
Expand Down
Loading
Loading