From 240402d28a2731e6e272f059dacfc1129d70c175 Mon Sep 17 00:00:00 2001 From: Dmitrii Blaginin Date: Thu, 21 Nov 2024 00:04:46 +0000 Subject: [PATCH] Coerce Array inner types (#13452) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Coerce Array inner types * `coerce_list_children` macro → fn * Rearrange & simplify patterns * Add casting for `FixedSizeList` * Add sql logic test * Fix test * Expand error * Switch to `Arc::clone` * OR nullable-s * Add successful test * Switch to `type_union_resolution` * Add a note on `coerce_list_children` * Handle different list types inside `type_union_resolution_coercion` * Add a `make_array` test with different array types * Add the test value * Update datafusion/sqllogictest/test_files/array.slt Co-authored-by: Jay Zhan * Reorder match --------- Co-authored-by: Jay Zhan --- .../expr-common/src/type_coercion/binary.rs | 73 +++++++++++++++---- datafusion/functions-nested/src/make_array.rs | 3 +- datafusion/sqllogictest/test_files/array.slt | 12 +++ datafusion/sqllogictest/test_files/union.slt | 16 ++++ 4 files changed, 86 insertions(+), 18 deletions(-) diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index c32b4951db44..bff74252df7b 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -480,11 +480,6 @@ fn type_union_resolution_coercion( let new_value_type = type_union_resolution_coercion(value_type, other_type); new_value_type.map(|t| DataType::Dictionary(index_type.clone(), Box::new(t))) } - (DataType::List(lhs), DataType::List(rhs)) => { - let new_item_type = - type_union_resolution_coercion(lhs.data_type(), rhs.data_type()); - new_item_type.map(|t| DataType::List(Arc::new(Field::new("item", t, true)))) - } (DataType::Struct(lhs), DataType::Struct(rhs)) => { if lhs.len() != rhs.len() { return None; @@ -529,6 +524,7 @@ fn type_union_resolution_coercion( // Numeric coercion is the same as comparison coercion, both find the narrowest type // that can accommodate both types binary_numeric_coercion(lhs_type, rhs_type) + .or_else(|| list_coercion(lhs_type, rhs_type)) .or_else(|| temporal_coercion_nonstrict_timezone(lhs_type, rhs_type)) .or_else(|| string_coercion(lhs_type, rhs_type)) .or_else(|| numeric_string_coercion(lhs_type, rhs_type)) @@ -1138,27 +1134,46 @@ fn numeric_string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option { + let data_types = vec![lhs_field.data_type().clone(), rhs_field.data_type().clone()]; + Some(Arc::new( + (**lhs_field) + .clone() + .with_data_type(type_union_resolution(&data_types)?) + .with_nullable(lhs_field.is_nullable() || rhs_field.is_nullable()), + )) +} + /// Coercion rules for list types. fn list_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { - (List(_), List(_)) => Some(lhs_type.clone()), - (LargeList(_), List(_)) => Some(lhs_type.clone()), - (List(_), LargeList(_)) => Some(rhs_type.clone()), - (LargeList(_), LargeList(_)) => Some(lhs_type.clone()), - (List(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), - (FixedSizeList(_, _), List(_)) => Some(rhs_type.clone()), // Coerce to the left side FixedSizeList type if the list lengths are the same, // otherwise coerce to list with the left type for dynamic length - (FixedSizeList(lf, ls), FixedSizeList(_, rs)) => { + (FixedSizeList(lhs_field, ls), FixedSizeList(rhs_field, rs)) => { if ls == rs { - Some(lhs_type.clone()) + Some(FixedSizeList( + coerce_list_children(lhs_field, rhs_field)?, + *rs, + )) } else { - Some(List(Arc::clone(lf))) + Some(List(coerce_list_children(lhs_field, rhs_field)?)) } } - (LargeList(_), FixedSizeList(_, _)) => Some(lhs_type.clone()), - (FixedSizeList(_, _), LargeList(_)) => Some(rhs_type.clone()), + // LargeList on any side + ( + LargeList(lhs_field), + List(rhs_field) | LargeList(rhs_field) | FixedSizeList(rhs_field, _), + ) + | (List(lhs_field) | FixedSizeList(lhs_field, _), LargeList(rhs_field)) => { + Some(LargeList(coerce_list_children(lhs_field, rhs_field)?)) + } + // Lists on both sides + (List(lhs_field), List(rhs_field) | FixedSizeList(rhs_field, _)) + | (FixedSizeList(lhs_field, _), List(rhs_field)) => { + Some(List(coerce_list_children(lhs_field, rhs_field)?)) + } _ => None, } } @@ -2105,10 +2120,36 @@ mod tests { DataType::List(Arc::clone(&inner_field)) ); + // Negative test: inner_timestamp_field and inner_field are not compatible because their inner types are not compatible + let inner_timestamp_field = Arc::new(Field::new( + "item", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + )); + let result_type = get_input_types( + &DataType::List(Arc::clone(&inner_field)), + &Operator::Eq, + &DataType::List(Arc::clone(&inner_timestamp_field)), + ); + assert!(result_type.is_err()); + // TODO add other data type Ok(()) } + #[test] + fn test_list_coercion() { + let lhs_type = DataType::List(Arc::new(Field::new("lhs", DataType::Int8, false))); + + let rhs_type = DataType::List(Arc::new(Field::new("rhs", DataType::Int64, true))); + + let coerced_type = list_coercion(&lhs_type, &rhs_type).unwrap(); + assert_eq!( + coerced_type, + DataType::List(Arc::new(Field::new("lhs", DataType::Int64, true))) + ); // nullable because the RHS is nullable + } + #[test] fn test_type_coercion_logical_op() -> Result<()> { test_coercion_binary_rule!( diff --git a/datafusion/functions-nested/src/make_array.rs b/datafusion/functions-nested/src/make_array.rs index de67b0ae3874..c84b6f010968 100644 --- a/datafusion/functions-nested/src/make_array.rs +++ b/datafusion/functions-nested/src/make_array.rs @@ -26,7 +26,7 @@ use arrow_array::{ new_null_array, Array, ArrayRef, GenericListArray, NullArray, OffsetSizeTrait, }; use arrow_buffer::OffsetBuffer; -use arrow_schema::DataType::{LargeList, List, Null}; +use arrow_schema::DataType::{List, Null}; use arrow_schema::{DataType, Field}; use datafusion_common::{plan_err, utils::array_into_list_array_nullable, Result}; use datafusion_expr::binary::{ @@ -198,7 +198,6 @@ pub(crate) fn make_array_inner(arrays: &[ArrayRef]) -> Result { let array = new_null_array(&DataType::Int64, length); Ok(Arc::new(array_into_list_array_nullable(array))) } - LargeList(..) => array_array::(arrays, data_type), _ => array_array::(arrays, data_type), } } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 1e60699a1f65..e6676d683f91 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1155,6 +1155,18 @@ select column1, column5 from arrays_values_without_nulls; [21, 22, 23, 24, 25, 26, 27, 28, 29, 30] [6, 7] [31, 32, 33, 34, 35, 26, 37, 38, 39, 40] [8, 9] +# make array with arrays of different types +query ? +select make_array(make_array(1), arrow_cast(make_array(-1), 'LargeList(Int8)')) +---- +[[1], [-1]] + +query T +select arrow_typeof(make_array(make_array(1), arrow_cast(make_array(-1), 'LargeList(Int8)'))); +---- +List(Field { name: "item", data_type: LargeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) + + query ??? select make_array(column1), make_array(column1, column5), diff --git a/datafusion/sqllogictest/test_files/union.slt b/datafusion/sqllogictest/test_files/union.slt index fb7afdda2ea8..b5e82f613a46 100644 --- a/datafusion/sqllogictest/test_files/union.slt +++ b/datafusion/sqllogictest/test_files/union.slt @@ -761,3 +761,19 @@ SELECT NULL WHERE FALSE; ---- 0.5 1 + +# Test Union of List Types. Issue: https://github.com/apache/datafusion/issues/12291 +query error DataFusion error: type_coercion\ncaused by\nError during planning: Incompatible inputs for Union: Previous inputs were of type List(.*), but got incompatible type List(.*) on column 'x' +SELECT make_array(2) x UNION ALL SELECT make_array(now()) x; + +query ? +select make_array(arrow_cast(2, 'UInt8')) x UNION ALL SELECT make_array(arrow_cast(-2, 'Int8')) x; +---- +[-2] +[2] + +query ? +select make_array(make_array(1)) x UNION ALL SELECT make_array(arrow_cast(make_array(-1), 'LargeList(Int8)')) x; +---- +[[-1]] +[[1]]