-
Notifications
You must be signed in to change notification settings - Fork 58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix some issues with arrow expression eval #401
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -147,41 +147,28 @@ impl ProvidesColumnByName for StructArray { | |
} | ||
} | ||
|
||
fn extract_column<'array, 'path>( | ||
array: &'array dyn ProvidesColumnByName, | ||
path_step: &str, | ||
remaining_path_steps: &mut impl Iterator<Item = &'path str>, | ||
) -> Result<&'array Arc<dyn Array>, ArrowError> { | ||
let child = array | ||
.column_by_name(path_step) | ||
.ok_or(ArrowError::SchemaError(format!( | ||
"No such field: {}", | ||
path_step, | ||
)))?; | ||
if let Some(next_path_step) = remaining_path_steps.next() { | ||
// This is not the last path step. Drill deeper. | ||
extract_column( | ||
column_as_struct(path_step, &Some(child))?, | ||
next_path_step, | ||
remaining_path_steps, | ||
) | ||
} else { | ||
// Last path step. Return it. | ||
Ok(child) | ||
fn extract_column<'a>( | ||
mut parent: &dyn ProvidesColumnByName, | ||
mut field_names: impl Iterator<Item = &'a str>, | ||
) -> DeltaResult<ArrayRef> { | ||
let Some(mut field_name) = field_names.next() else { | ||
return Err(ArrowError::SchemaError("Empty column path".to_string()))?; | ||
}; | ||
loop { | ||
let child = parent | ||
.column_by_name(field_name) | ||
.ok_or_else(|| ArrowError::SchemaError(format!("No such field: {field_name}")))?; | ||
field_name = match field_names.next() { | ||
Some(name) => name, | ||
None => return Ok(child.clone()), | ||
}; | ||
parent = child | ||
.as_any() | ||
.downcast_ref::<StructArray>() | ||
.ok_or_else(|| ArrowError::SchemaError(format!("Not a struct: {field_name}")))?; | ||
} | ||
} | ||
|
||
fn column_as_struct<'a>( | ||
name: &str, | ||
column: &Option<&'a Arc<dyn Array>>, | ||
) -> Result<&'a StructArray, ArrowError> { | ||
column | ||
.ok_or(ArrowError::SchemaError(format!("No such column: {}", name)))? | ||
.as_any() | ||
.downcast_ref::<StructArray>() | ||
.ok_or(ArrowError::SchemaError(format!("{} is not a struct", name))) | ||
} | ||
|
||
fn evaluate_expression( | ||
expression: &Expression, | ||
batch: &RecordBatch, | ||
|
@@ -191,20 +178,7 @@ fn evaluate_expression( | |
use Expression::*; | ||
match (expression, result_type) { | ||
(Literal(scalar), _) => Ok(scalar.to_array(batch.num_rows())?), | ||
(Column(name), _) => { | ||
// TODO properly handle nested columns | ||
// https://github.com/delta-incubator/delta-kernel-rs/issues/86 | ||
if name.contains('.') { | ||
let mut path = name.split('.'); | ||
// Safety: we know that the first path step exists, because we checked for '.' | ||
Ok(extract_column(batch, path.next().unwrap(), &mut path).cloned()?) | ||
} else { | ||
batch | ||
.column_by_name(name) | ||
.ok_or(Error::missing_column(name)) | ||
.cloned() | ||
} | ||
} | ||
(Column(name), _) => extract_column(batch, name.split('.')), | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Do we still need to have a follow-up in order to properly handle nested columns (instead of just the naive `split('.')`)?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Definitely. My WIP nested column code already handles it.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Ah, you refer to the deleted TODO comment. Reinstating it. |
||
(Struct(fields), Some(DataType::Struct(output_schema))) => { | ||
let columns = fields | ||
.iter() | ||
|
@@ -244,22 +218,15 @@ fn evaluate_expression( | |
}, | ||
_, | ||
) => match (left.as_ref(), right.as_ref()) { | ||
(Literal(_), Column(c)) => { | ||
let list_type = batch.column_by_name(c).map(|c| c.data_type()); | ||
if !matches!( | ||
list_type, | ||
Some(ArrowDataType::List(_)) | Some(ArrowDataType::FixedSizeList(_, _)) | ||
) { | ||
return Err(Error::InvalidExpressionEvaluation(format!( | ||
"Right side column: {c} is not a list or a fixed size list" | ||
))); | ||
} | ||
(Literal(_), Column(_)) => { | ||
let left_arr = evaluate_expression(left.as_ref(), batch, None)?; | ||
let right_arr = evaluate_expression(right.as_ref(), batch, None)?; | ||
if let Some(string_arr) = left_arr.as_string_opt::<i32>() { | ||
return in_list_utf8(string_arr, right_arr.as_list::<i32>()) | ||
.map(wrap_comparison_result) | ||
.map_err(Error::generic_err); | ||
if let Some(right_arr) = right_arr.as_list_opt::<i32>() { | ||
return in_list_utf8(string_arr, right_arr) | ||
.map(wrap_comparison_result) | ||
.map_err(Error::generic_err); | ||
} | ||
} | ||
prim_array_cmp! { | ||
left_arr, right_arr, | ||
|
@@ -454,7 +421,7 @@ mod tests { | |
); | ||
|
||
let in_op = Expression::binary( | ||
BinaryOperator::NotIn, | ||
BinaryOperator::In, | ||
Expression::literal(5), | ||
Expression::column("item"), | ||
); | ||
|
@@ -464,7 +431,7 @@ mod tests { | |
assert_eq!(result.as_ref(), &expected); | ||
|
||
let in_result = evaluate_expression(&in_op, &batch, None).unwrap(); | ||
let in_expected = BooleanArray::from(vec![true, false, true]); | ||
let in_expected = BooleanArray::from(vec![false, true, false]); | ||
assert_eq!(in_result.as_ref(), &in_expected); | ||
} | ||
|
||
|
@@ -486,8 +453,8 @@ mod tests { | |
assert!(in_result.is_err()); | ||
assert_eq!( | ||
in_result.unwrap_err().to_string(), | ||
"Invalid expression evaluation: Right side column: item is not a list or a fixed size list".to_string() | ||
) | ||
"Invalid expression evaluation: Cannot cast to list array: Int32" | ||
); | ||
} | ||
|
||
#[test] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe a super quick example would also be useful? E.g., if the parent is a complex struct column `a` and the field names are `[b, c]`, does that mean we return a ref to the nested column `c`? Am I understanding correctly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Exactly. Added the missing doc comment.