diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fd899289ac82..cf0558f16818 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -591,8 +591,9 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] { List(field) => Ok(field.data_type().clone()), + LargeList(field) => Ok(field.data_type().clone()), _ => plan_err!( - "The {self} function can only accept list as the first argument" + "The {self} function can only accept list or largelist as the first argument" ), }, BuiltinScalarFunction::ArrayLength => Ok(UInt64), diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 7ccf58af832d..e7892a13c81c 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -369,18 +369,14 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result { } } -/// array_element SQL function -/// -/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. -/// `array_element(array, index)` -/// -/// For example: -/// > array_element(\[1, 2, 3], 2) -> 2 -pub fn array_element(args: &[ArrayRef]) -> Result { - let list_array = as_list_array(&args[0])?; - let indexes = as_int64_array(&args[1])?; - - let values = list_array.values(); +fn general_array_element( + array: &GenericListArray, + indexes: &Int64Array, +) -> Result +where + i64: TryInto, +{ + let values = array.values(); let original_data = values.to_data(); let capacity = Capacities::Array(original_data.len()); @@ -388,37 +384,47 @@ pub fn array_element(args: &[ArrayRef]) -> Result { let mut mutable = MutableArrayData::with_capacities(vec![&original_data], true, capacity); - fn adjusted_array_index(index: i64, len: usize) -> Option { + fn adjusted_array_index(index: i64, len: O) -> Result> + where + i64: TryInto, + { + let index: O = index.try_into().map_err(|_| { + DataFusionError::Execution(format!( + "array_element got invalid index: {}", + index + )) + })?; // 0 ~ len - 1 - let adjusted_zero_index = if index < 0 { - index + len as i64 + let adjusted_zero_index = if index < O::usize_as(0) { + index + len } else { - index - 1 + index - O::usize_as(1) }; - if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 { - Some(adjusted_zero_index) + if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len { + Ok(Some(adjusted_zero_index)) } else { // Out of bounds - None + Ok(None) } } - for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() { - let start = offset_window[0] as usize; - let end = offset_window[1] as usize; + for (row_index, offset_window) in array.offsets().windows(2).enumerate() { + let start = offset_window[0]; + let end = offset_window[1]; let len = end - start; // array is null - if len == 0 { + if len == O::usize_as(0) { mutable.extend_nulls(1); continue; } - let index = adjusted_array_index(indexes.value(row_index), len); + let index = adjusted_array_index::(indexes.value(row_index), len)?; if let Some(index) = index { - mutable.extend(0, start + index as usize, start + index as usize + 1); + let start = start.as_usize() + index.as_usize(); + mutable.extend(0, start, start + 1_usize); } else { // Index out of bounds mutable.extend_nulls(1); @@ -429,6 +435,32 @@ pub fn array_element(args: &[ArrayRef]) -> Result { Ok(arrow_array::make_array(data)) } +/// array_element SQL function +/// +/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index. +/// `array_element(array, index)` +/// +/// For example: +/// > array_element(\[1, 2, 3], 2) -> 2 +pub fn array_element(args: &[ArrayRef]) -> Result { + match &args[0].data_type() { + DataType::List(_) => { + let array = as_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) + } + DataType::LargeList(_) => { + let array = as_large_list_array(&args[0])?; + let indexes = as_int64_array(&args[1])?; + general_array_element::(array, indexes) + } + _ => not_impl_err!( + "array_element does not support type: {:?}", + args[0].data_type() + ), + } +} + fn general_except( l: &GenericListArray, r: &GenericListArray, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 210739aa51da..8540e6029853 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -719,7 +719,7 @@ from arrays_values_without_nulls; ## array_element (aliases: array_extract, list_extract, list_element) # array_element error -query error DataFusion error: Error during planning: The array_element function can only accept list as the first argument +query error DataFusion error: Error during planning: The array_element function can only accept list or largelist as the first argument select array_element(1, 2); @@ -729,58 +729,106 @@ select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h' ---- 2 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element scalar function #2 (with positive index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 7), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 11); +---- +NULL NULL + # array_element scalar function #3 (with zero) query IT select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0); +---- +NULL NULL + # array_element scalar function #4 (with NULL) query error select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL); +query error +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL); + # array_element scalar function #5 (with negative index) query IT select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3); ---- 4 l +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3); +---- +4 l + # array_element scalar function #6 (with negative index; out of bounds) query IT select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7); ---- NULL NULL +query IT +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -11), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7); +---- +NULL NULL + # array_element scalar function #7 (nested array) query ? select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1); ---- [1, 2, 3, 4, 5] +query ? +select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1); +---- +[1, 2, 3, 4, 5] + # array_extract scalar function #8 (function alias `array_slice`) query IT select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_element scalar function #9 (function alias `array_slice`) query IT select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # list_extract scalar function #10 (function alias `array_slice`) query IT select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3); ---- 2 l +query IT +select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3); +---- +2 l + # array_element with columns query I select array_element(column1, column2) from slices; @@ -793,6 +841,17 @@ NULL NULL 55 +query I +select array_element(arrow_cast(column1, 'LargeList(Int64)'), column2) from slices; +---- +NULL +12 +NULL +37 +NULL +NULL +55 + # array_element with columns and scalars query II select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices; @@ -805,6 +864,17 @@ NULL 23 NULL 43 5 NULL +query II +select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_element(arrow_cast(column1, 'LargeList(Int64)'), 3) from slices; +---- +1 3 +2 13 +NULL 23 +2 33 +4 NULL +NULL 43 +5 NULL + ## array_pop_back (aliases: `list_pop_back`) # array_pop_back scalar function #1