From 372da21769d2e5ca26a39e6695c2077b29257d48 Mon Sep 17 00:00:00 2001 From: Alex Huang Date: Sun, 3 Dec 2023 13:35:17 +0100 Subject: [PATCH] feat: support `LargeList` in `make_array` and `array_length` (#8121) * feat: support LargeList in make_array and array_length * chore: add tests * fix: update tests for nested array * use usize_as * add new_large_list * refactor array_length * add comment * update test in sqllogictest * fix ci * fix macro * use usize_as * update comment * return based on data_type in make_array --- .../physical-expr/src/array_expressions.rs | 47 +++++++++++++----- datafusion/sqllogictest/test_files/array.slt | 49 +++++++++++++++++++ 2 files changed, 83 insertions(+), 13 deletions(-) diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 84dfe3b9ff75..0601c22ecfb4 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -171,6 +171,10 @@ fn compute_array_length( value = downcast_arg!(value, ListArray).value(0); current_dimension += 1; } + DataType::LargeList(..) => { + value = downcast_arg!(value, LargeListArray).value(0); + current_dimension += 1; + } _ => return Ok(None), } } @@ -252,7 +256,7 @@ macro_rules! call_array_function { } /// Convert one or more [`ArrayRef`] of the same type into a -/// `ListArray` +/// `ListArray` or `LargeListArray` depending on the offset size. /// /// # Example (non nested) /// @@ -291,7 +295,10 @@ macro_rules! call_array_function { /// └──────────────┘ └──────────────┘ └─────────────────────────────┘ /// col1 col2 output /// ``` -fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { +fn array_array( + args: &[ArrayRef], + data_type: DataType, +) -> Result { // do not accept 0 arguments. 
if args.is_empty() { return plan_err!("Array requires at least one argument"); @@ -308,8 +315,9 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { total_len += arg_data.len(); data.push(arg_data); } - let mut offsets = Vec::with_capacity(total_len); - offsets.push(0); + + let mut offsets: Vec = Vec::with_capacity(total_len); + offsets.push(O::usize_as(0)); let capacity = Capacities::Array(total_len); let data_ref = data.iter().collect::>(); @@ -327,11 +335,11 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { mutable.extend_nulls(1); } } - offsets.push(mutable.len() as i32); + offsets.push(O::usize_as(mutable.len())); } - let data = mutable.freeze(); - Ok(Arc::new(ListArray::try_new( + + Ok(Arc::new(GenericListArray::::try_new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::new(offsets.into()), arrow_array::make_array(data), @@ -356,7 +364,8 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { let array = new_null_array(&DataType::Null, arrays.len()); Ok(Arc::new(array_into_list_array(array))) } - data_type => array_array(arrays, data_type), + DataType::LargeList(..) => array_array::(arrays, data_type), + _ => array_array::(arrays, data_type), } } @@ -1693,11 +1702,11 @@ pub fn flatten(args: &[ArrayRef]) -> Result { Ok(Arc::new(flattened_array) as ArrayRef) } -/// Array_length SQL function -pub fn array_length(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let dimension = if args.len() == 2 { - as_int64_array(&args[1])?.clone() +/// Dispatch array length computation based on the offset type. 
+fn array_length_dispatch(array: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&array[0])?; + let dimension = if array.len() == 2 { + as_int64_array(&array[1])?.clone() } else { Int64Array::from_value(1, list_array.len()) }; @@ -1711,6 +1720,18 @@ pub fn array_length(args: &[ArrayRef]) -> Result { Ok(Arc::new(result) as ArrayRef) } +/// Array_length SQL function +pub fn array_length(args: &[ArrayRef]) -> Result { + match &args[0].data_type() { + DataType::List(_) => array_length_dispatch::(args), + DataType::LargeList(_) => array_length_dispatch::(args), + _ => internal_err!( + "array_length does not support type '{:?}'", + args[0].data_type() + ), + } +} + /// Array_dims SQL function pub fn array_dims(args: &[ArrayRef]) -> Result { let list_array = as_list_array(&args[0])?; diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 092bc697a197..6ec2b2cb013b 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2371,24 +2371,44 @@ select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3) ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + # array_length scalar function #2 query III select array_length(make_array(1, 2, 3, 4, 5), 1), array_length(make_array(1, 2, 3), 1), array_length(make_array([1, 2], [3, 4], [5, 6]), 1); ---- 5 3 3 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 1), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 1); +---- +5 3 3 + # array_length scalar function #3 query III select array_length(make_array(1, 2, 3, 4, 5), 2), 
array_length(make_array(1, 2, 3), 2), array_length(make_array([1, 2], [3, 4], [5, 6]), 2); ---- NULL NULL 2 +query III +select array_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)'), 2), array_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))'), 2); +---- +NULL NULL 2 + # array_length scalar function #4 query II select array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 1), array_length(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 2); ---- 3 2 +query II +select array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 1), array_length(arrow_cast(array_repeat(array_repeat(array_repeat(3, 5), 2), 3), 'LargeList(List(List(Int64)))'), 2); +---- +3 2 + # array_length scalar function #5 query III select array_length(make_array()), array_length(make_array(), 1), array_length(make_array(), 2) @@ -2407,6 +2427,11 @@ select list_length(make_array(1, 2, 3, 4, 5)), list_length(make_array(1, 2, 3)), ---- 5 3 3 NULL +query III +select list_length(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')), list_length(arrow_cast(make_array(1, 2, 3), 'LargeList(Int64)')), list_length(arrow_cast(make_array([1, 2], [3, 4], [5, 6]), 'LargeList(List(Int64))')); +---- +5 3 3 + # array_length with columns query I select array_length(column1, column3) from arrays_values; @@ -2420,6 +2445,18 @@ NULL NULL NULL +query I +select array_length(arrow_cast(column1, 'LargeList(Int64)'), column3) from arrays_values; +---- +10 +NULL +NULL +NULL +NULL +NULL +NULL +NULL + # array_length with columns and scalars query II select array_length(array[array[1, 2], array[3, 4]], column3), array_length(column1, 1) from arrays_values; @@ -2433,6 +2470,18 @@ NULL 10 NULL 10 NULL 10 +query II +select array_length(arrow_cast(array[array[1, 2], array[3, 4]], 'LargeList(List(Int64))'), column3), 
array_length(arrow_cast(column1, 'LargeList(Int64)'), 1) from arrays_values; +---- +2 10 +2 10 +NULL 10 +NULL 10 +NULL NULL +NULL 10 +NULL 10 +NULL 10 + ## array_dims (aliases: `list_dims`) # array dims error