Skip to content

Commit

Permalink
support LargeList in array_element (#8570)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H authored Dec 18, 2023
1 parent d220bf4 commit d33ca4d
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 27 deletions.
3 changes: 2 additions & 1 deletion datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,9 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
List(field) => Ok(field.data_type().clone()),
LargeList(field) => Ok(field.data_type().clone()),
_ => plan_err!(
"The {self} function can only accept list as the first argument"
"The {self} function can only accept list or largelist as the first argument"
),
},
BuiltinScalarFunction::ArrayLength => Ok(UInt64),
Expand Down
82 changes: 57 additions & 25 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,56 +370,62 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
}
}

/// array_element SQL function
///
/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index.
/// `array_element(array, index)`
///
/// For example:
/// > array_element(\[1, 2, 3], 2) -> 2
pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_list_array(&args[0])?;
let indexes = as_int64_array(&args[1])?;

let values = list_array.values();
fn general_array_element<O: OffsetSizeTrait>(
array: &GenericListArray<O>,
indexes: &Int64Array,
) -> Result<ArrayRef>
where
i64: TryInto<O>,
{
let values = array.values();
let original_data = values.to_data();
let capacity = Capacities::Array(original_data.len());

// use_nulls: true, we don't construct List for array_element, so we need explicit nulls.
let mut mutable =
MutableArrayData::with_capacities(vec![&original_data], true, capacity);

fn adjusted_array_index(index: i64, len: usize) -> Option<i64> {
fn adjusted_array_index<O: OffsetSizeTrait>(index: i64, len: O) -> Result<Option<O>>
where
i64: TryInto<O>,
{
let index: O = index.try_into().map_err(|_| {
DataFusionError::Execution(format!(
"array_element got invalid index: {}",
index
))
})?;
// 0 ~ len - 1
let adjusted_zero_index = if index < 0 {
index + len as i64
let adjusted_zero_index = if index < O::usize_as(0) {
index + len
} else {
index - 1
index - O::usize_as(1)
};

if 0 <= adjusted_zero_index && adjusted_zero_index < len as i64 {
Some(adjusted_zero_index)
if O::usize_as(0) <= adjusted_zero_index && adjusted_zero_index < len {
Ok(Some(adjusted_zero_index))
} else {
// Out of bounds
None
Ok(None)
}
}

for (row_index, offset_window) in list_array.offsets().windows(2).enumerate() {
let start = offset_window[0] as usize;
let end = offset_window[1] as usize;
for (row_index, offset_window) in array.offsets().windows(2).enumerate() {
let start = offset_window[0];
let end = offset_window[1];
let len = end - start;

// array is null
if len == 0 {
if len == O::usize_as(0) {
mutable.extend_nulls(1);
continue;
}

let index = adjusted_array_index(indexes.value(row_index), len);
let index = adjusted_array_index::<O>(indexes.value(row_index), len)?;

if let Some(index) = index {
mutable.extend(0, start + index as usize, start + index as usize + 1);
let start = start.as_usize() + index.as_usize();
mutable.extend(0, start, start + 1_usize);
} else {
// Index out of bounds
mutable.extend_nulls(1);
Expand All @@ -430,6 +436,32 @@ pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(arrow_array::make_array(data))
}

/// array_element SQL function
///
/// There are two arguments for array_element, the first one is the array, the second one is the 1-indexed index.
/// `array_element(array, index)`
///
/// For example:
/// > array_element(\[1, 2, 3], 2) -> 2
pub fn array_element(args: &[ArrayRef]) -> Result<ArrayRef> {
match &args[0].data_type() {
DataType::List(_) => {
let array = as_list_array(&args[0])?;
let indexes = as_int64_array(&args[1])?;
general_array_element::<i32>(array, indexes)
}
DataType::LargeList(_) => {
let array = as_large_list_array(&args[0])?;
let indexes = as_int64_array(&args[1])?;
general_array_element::<i64>(array, indexes)
}
_ => not_impl_err!(
"array_element does not support type: {:?}",
args[0].data_type()
),
}
}

fn general_except<OffsetSize: OffsetSizeTrait>(
l: &GenericListArray<OffsetSize>,
r: &GenericListArray<OffsetSize>,
Expand Down
72 changes: 71 additions & 1 deletion datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ from arrays_values_without_nulls;
## array_element (aliases: array_extract, list_extract, list_element)

# array_element error
query error DataFusion error: Error during planning: The array_element function can only accept list as the first argument
query error DataFusion error: Error during planning: The array_element function can only accept list or largelist as the first argument
select array_element(1, 2);


Expand All @@ -727,58 +727,106 @@ select array_element(make_array(1, 2, 3, 4, 5), 2), array_element(make_array('h'
----
2 l

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
----
2 l

# array_element scalar function #2 (with positive index; out of bounds)
query IT
select array_element(make_array(1, 2, 3, 4, 5), 7), array_element(make_array('h', 'e', 'l', 'l', 'o'), 11);
----
NULL NULL

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 7), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 11);
----
NULL NULL

# array_element scalar function #3 (with zero)
query IT
select array_element(make_array(1, 2, 3, 4, 5), 0), array_element(make_array('h', 'e', 'l', 'l', 'o'), 0);
----
NULL NULL

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 0), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 0);
----
NULL NULL

# array_element scalar function #4 (with NULL)
query error
select array_element(make_array(1, 2, 3, 4, 5), NULL), array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL);

query error
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), NULL), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), NULL);

# array_element scalar function #5 (with negative index)
query IT
select array_element(make_array(1, 2, 3, 4, 5), -2), array_element(make_array('h', 'e', 'l', 'l', 'o'), -3);
----
4 l

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -2), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -3);
----
4 l

# array_element scalar function #6 (with negative index; out of bounds)
query IT
select array_element(make_array(1, 2, 3, 4, 5), -11), array_element(make_array('h', 'e', 'l', 'l', 'o'), -7);
----
NULL NULL

query IT
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -11), array_element(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7);
----
NULL NULL

# array_element scalar function #7 (nested array)
query ?
select array_element(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 1);
----
[1, 2, 3, 4, 5]

query ?
select array_element(arrow_cast(make_array(make_array(1, 2, 3, 4, 5), make_array(6, 7, 8, 9, 10)), 'LargeList(List(Int64))'), 1);
----
[1, 2, 3, 4, 5]

# array_extract scalar function #8 (function alias `array_slice`)
query IT
select array_extract(make_array(1, 2, 3, 4, 5), 2), array_extract(make_array('h', 'e', 'l', 'l', 'o'), 3);
----
2 l

query IT
select array_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
----
2 l

# list_element scalar function #9 (function alias `array_slice`)
query IT
select list_element(make_array(1, 2, 3, 4, 5), 2), list_element(make_array('h', 'e', 'l', 'l', 'o'), 3);
----
2 l

query IT
select list_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
----
2 l

# list_extract scalar function #10 (function alias `array_slice`)
query IT
select list_extract(make_array(1, 2, 3, 4, 5), 2), list_extract(make_array('h', 'e', 'l', 'l', 'o'), 3);
----
2 l

query IT
select list_extract(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2), array_extract(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 3);
----
2 l

# array_element with columns
query I
select array_element(column1, column2) from slices;
Expand All @@ -791,6 +839,17 @@ NULL
NULL
55

query I
select array_element(arrow_cast(column1, 'LargeList(Int64)'), column2) from slices;
----
NULL
12
NULL
37
NULL
NULL
55

# array_element with columns and scalars
query II
select array_element(make_array(1, 2, 3, 4, 5), column2), array_element(column1, 3) from slices;
Expand All @@ -803,6 +862,17 @@ NULL 23
NULL 43
5 NULL

query II
select array_element(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_element(arrow_cast(column1, 'LargeList(Int64)'), 3) from slices;
----
1 3
2 13
NULL 23
2 33
4 NULL
NULL 43
5 NULL

## array_pop_back (aliases: `list_pop_back`)

# array_pop_back scalar function #1
Expand Down

0 comments on commit d33ca4d

Please sign in to comment.