Skip to content

Commit

Permalink
fix: disable datafusion optimizers which perform list field name erasure
Browse files Browse the repository at this point in the history
Today the make_array function from Datafusion uses "item" as the list
element's field name. With recent changes in delta-kernel-rs we have
switched to calling it "element" which is more conventional related to
how Apache Parquet handles things

This change introduces a test which helps isolate the behavior seen in
Python tests within the core crate for easier regression testing

Signed-off-by: R. Tyler Croy <[email protected]>
  • Loading branch information
rtyler committed Nov 21, 2024
1 parent 9c46985 commit b144cf1
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 24 deletions.
176 changes: 161 additions & 15 deletions crates/core/src/delta_datafusion/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,171 @@
use std::fmt::{self, Display, Error, Formatter, Write};
use std::sync::Arc;

use arrow_schema::DataType;
use arrow_array::{Array, GenericListArray};
use arrow_schema::{DataType, Field};
use chrono::{DateTime, NaiveDate};
use datafusion::execution::context::SessionState;
use datafusion::execution::session_state::SessionStateBuilder;
use datafusion::execution::FunctionRegistry;
use datafusion::functions_array::make_array::MakeArray;
use datafusion_common::Result as DFResult;
use datafusion_common::{config::ConfigOptions, DFSchema, Result, ScalarValue, TableReference};
use datafusion_expr::expr::InList;
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::{AggregateUDF, Between, BinaryExpr, Cast, Expr, Like, TableSource};
// Needed for MakeParquetArray
use datafusion_expr::{ColumnarValue, Documentation, ScalarUDF, ScalarUDFImpl, Signature};
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use datafusion_sql::sqlparser::ast::escape_quoted_string;
use datafusion_sql::sqlparser::dialect::GenericDialect;
use datafusion_sql::sqlparser::parser::Parser;
use datafusion_sql::sqlparser::tokenizer::Tokenizer;
use tracing::log::*;

use super::DeltaParserOptions;
use crate::{DeltaResult, DeltaTableError};

/// This struct is like Datafusion's MakeArray but ensures that `element` is used rather than `item
/// as the field name within the list.
#[derive(Debug)]
struct MakeParquetArray {
/// The actual upstream UDF, which we're just totally cheating and using
actual: MakeArray,
/// Aliases for this UDF
aliases: Vec<String>,
}

impl MakeParquetArray {
pub fn new() -> Self {
let actual = MakeArray::default();
let aliases = vec!["make_array".into(), "make_list".into()];
Self { actual, aliases }
}
}

impl ScalarUDFImpl for MakeParquetArray {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"make_parquet_array"
}

fn signature(&self) -> &Signature {
self.actual.signature()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
let r_type = match arg_types.len() {
0 => Ok(DataType::List(Arc::new(Field::new(
"element",
DataType::Int32,
true,
)))),
_ => {
// At this point, all the type in array should be coerced to the same one
Ok(DataType::List(Arc::new(Field::new(
"element",
arg_types[0].to_owned(),
true,
))))
}
};
debug!("MakeParquetArray return_type -> {r_type:?}");
r_type
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let mut data_type = DataType::Null;
for arg in args {
data_type = arg.data_type();
}

match self.actual.invoke(args)? {
ColumnarValue::Scalar(ScalarValue::List(df_array)) => {
let field = Arc::new(Field::new("element", data_type, true));
let result = Ok(ColumnarValue::Scalar(ScalarValue::List(Arc::new(
GenericListArray::<i32>::try_new(
field,
df_array.offsets().clone(),
arrow_array::make_array(df_array.values().into_data()),
None,
)?,
))));
debug!("MakeParquetArray;invoke returning: {result:?}");
result
}
others => {
error!("Unexpected response inside MakeParquetArray! {others:?}");
Ok(others)
}
}
}

fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
self.actual.invoke_no_args(number_rows)
}

fn aliases(&self) -> &[String] {
&self.aliases
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
self.actual.coerce_types(arg_types)
}

fn documentation(&self) -> Option<&Documentation> {
self.actual.documentation()
}
}

use datafusion::functions_array::planner::{FieldAccessPlanner, NestedFunctionPlanner};

/// This exists becxause the NestedFunctionPlanner _not_ the UserDefinedFunctionPlanner handles the
/// insertion of "make_array" which is used to turn [100] into List<field=element, values=[100]>
///
/// **screaming intensifies**
#[derive(Debug)]
struct CustomNestedFunctionPlanner {
original: NestedFunctionPlanner,
}

impl Default for CustomNestedFunctionPlanner {
fn default() -> Self {
Self {
original: NestedFunctionPlanner,
}
}
}

use datafusion_expr::planner::{PlannerResult, RawBinaryExpr};
impl ExprPlanner for CustomNestedFunctionPlanner {
fn plan_array_literal(
&self,
exprs: Vec<Expr>,
_schema: &DFSchema,
) -> Result<PlannerResult<Vec<Expr>>> {
let udf = Arc::new(ScalarUDF::from(MakeParquetArray::new()));

Ok(PlannerResult::Planned(udf.call(exprs)))
}
fn plan_binary_op(
&self,
expr: RawBinaryExpr,
schema: &DFSchema,
) -> Result<PlannerResult<RawBinaryExpr>> {
self.original.plan_binary_op(expr, schema)
}
fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
self.original.plan_make_map(args)
}
fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
self.original.plan_any(expr)
}
}

pub(crate) struct DeltaContextProvider<'a> {
state: SessionState,
/// Keeping this around just to make use of the 'a lifetime
Expand All @@ -51,22 +197,22 @@ pub(crate) struct DeltaContextProvider<'a> {

impl<'a> DeltaContextProvider<'a> {
fn new(state: &'a SessionState) -> Self {
let planners = state.expr_planners();
// default planners are [CoreFunctionPlanner, NestedFunctionPlanner, FieldAccessPlanner,
// UserDefinedFunctionPlanner]
let planners: Vec<Arc<dyn ExprPlanner>> = vec![
Arc::new(CoreFunctionPlanner::default()),
Arc::new(CustomNestedFunctionPlanner::default()),
Arc::new(FieldAccessPlanner),
Arc::new(datafusion::functions::planner::UserDefinedFunctionPlanner),
];
// Disable the above for testing
//let planners = state.expr_planners();
let new_state = SessionStateBuilder::new_from_existing(state.clone())
.with_expr_planners(planners.clone())
.build();
DeltaContextProvider {
planners,
// Creating a new session state with overridden scalar_functions since
// the get_field() UDF was dropped from the default scalar functions upstream in
// `36660fe10d9c0cdff62e0da0b94bee28422d3419`
state: SessionStateBuilder::new_from_existing(state.clone())
.with_scalar_functions(
state
.scalar_functions()
.values()
.cloned()
.chain(std::iter::once(datafusion::functions::core::get_field()))
.collect(),
)
.build(),
state: new_state,
_original: state,
}
}
Expand Down
34 changes: 28 additions & 6 deletions crates/core/src/operations/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,12 @@ mod tests {
fn test_merge_arrow_schema_with_nested() {
let left_schema = Arc::new(Schema::new(vec![Field::new(
"f",
DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, false))),
DataType::LargeList(Arc::new(Field::new("element", DataType::Utf8, false))),
false,
)]));
let right_schema = Arc::new(Schema::new(vec![Field::new(
"f",
DataType::List(Arc::new(Field::new("item", DataType::LargeUtf8, false))),
DataType::List(Arc::new(Field::new("element", DataType::LargeUtf8, false))),
true,
)]));

Expand All @@ -306,7 +306,7 @@ mod tests {

let fields = Fields::from(vec![Field::new_list(
"list_column",
Field::new("item", DataType::Int8, false),
Field::new("element", DataType::Int8, false),
false,
)]);
let target_schema = Arc::new(Schema::new(fields)) as SchemaRef;
Expand All @@ -316,7 +316,7 @@ mod tests {
let schema = result.unwrap().schema();
let field = schema.column_with_name("list_column").unwrap().1;
if let DataType::List(list_item) = field.data_type() {
assert_eq!(list_item.name(), "item");
assert_eq!(list_item.name(), "element");
} else {
panic!("Not a list");
}
Expand All @@ -343,12 +343,34 @@ mod tests {

#[test]
fn test_is_cast_required_with_list() {
let field1 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));
let field1 = DataType::List(FieldRef::from(Field::new(
"element",
DataType::Int32,
false,
)));
let field2 = DataType::List(FieldRef::from(Field::new(
"element",
DataType::Int32,
false,
)));

assert!(!is_cast_required(&field1, &field2));
}

/// Delta has adopted "element" as the default list field name rather than the previously used
/// "item". This lines up more with Apache Parquet but should be handled in casting
#[test]
fn test_is_cast_required_with_old_and_new_list() {
let field1 = DataType::List(FieldRef::from(Field::new(
"element",
DataType::Int32,
false,
)));
let field2 = DataType::List(FieldRef::from(Field::new("item", DataType::Int32, false)));

assert!(is_cast_required(&field1, &field2));
}

#[test]
fn test_is_cast_required_with_smol_int() {
assert!(is_cast_required(&DataType::Int8, &DataType::Int32));
Expand Down
Loading

0 comments on commit b144cf1

Please sign in to comment.