diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 129b90ee047..43a707aa197 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -34,9 +34,11 @@ use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values} use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer}; +use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, Buffer, NullBuffer, OffsetBuffer}; use arrow_data::transform::{Capacities, MutableArrayData}; +use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType, SchemaRef}; +use num::Saturating; use std::sync::Arc; fn binary_capacity(arrays: &[&dyn Array]) -> Capacities { @@ -129,12 +131,149 @@ fn concat_dictionaries( Ok(Arc::new(array)) } +fn concat_list_of_dictionaries( + arrays: &[&dyn Array], +) -> Result { + let mut output_len = 0; + let lists = arrays + .iter() + .map(|x| x.as_list::()) + .collect::>(); + + let dictionaries: Vec<_> = lists + .iter() + .map(|x| x.values().as_ref().as_dictionary::()) + // TODO? + .inspect(|d| output_len += d.len()) + .collect(); + + if !should_merge_dictionary_values::(&dictionaries, output_len) { + return concat_fallback(arrays, Capacities::Array(output_len)); + } + + let merged = merge_dictionary_values(&dictionaries, None)?; + + let lists_nulls = lists + .iter() + .fold(None, |acc, a| NullBuffer::union(acc.as_ref(), a.nulls())); + + + // Recompute keys + let mut key_values = Vec::with_capacity(output_len); + + let mut dictionary_has_nulls = false; + for (d, mapping) in dictionaries.iter().zip(merged.key_mappings) { + dictionary_has_nulls |= d.null_count() != 0; + for key in d.keys().values() { + // Use get to safely handle nulls + key_values.push(mapping.get(key.as_usize()).copied().unwrap_or_default()) + } + } + + let dictionary_nulls = dictionary_has_nulls.then(|| { + let mut nulls = BooleanBufferBuilder::new(output_len); + for d in &dictionaries { + match d.nulls() { + Some(n) => nulls.append_buffer(n.inner()), + None => nulls.append_n(d.len(), true), + } + } + NullBuffer::new(nulls.finish()) + }); + + let keys = PrimitiveArray::::new(key_values.into(), dictionary_nulls); + // Sanity check + assert_eq!(keys.len(), output_len); + + let array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) }; + + // Merge value offsets from the lists + let all_value_offsets_iterator = lists + .iter() + .map(|x| x.offsets()); + + let value_offset_buffer = merge_value_offsets(all_value_offsets_iterator); + + let builder = ArrayDataBuilder::new(arrays[0].data_type().clone()) + .len(output_len) + .nulls(lists_nulls) + // `GenericListArray` must only have 1 buffer + .buffers(vec![value_offset_buffer]) + // `GenericListArray` must only have 1 child_data + .child_data(vec![array.to_data()]); + + // TODO - maybe use build_unchecked? + let array_data = builder.build()?; + + let array = GenericListArray::::from(array_data); + Ok(Arc::new(array)) +} + +/// Merge value offsets +/// +/// +/// if we have the following +/// [[0, 3, 5], [0, 2, 2, 8], [], [0, 0, 1]] +/// The output should be +/// [ 0, 3, 5, 7, 7, 13, 13, 14] +fn merge_value_offsets<'a, OffsetSize: OffsetSizeTrait, I: Iterator>>(offset_buffers_iterator: I) -> Buffer { + // 1. Filter out empty lists + let mut offset_buffers_iterator = offset_buffers_iterator.filter(|x| !x.is_empty()); + + // 2. Get first non-empty list as the starting point + let starting_buffer = offset_buffers_iterator.next(); + + // 3. If we have only empty lists, return an empty buffer + if starting_buffer.is_none() { + return Buffer::from(&[]) + } + + let starting_buffer = starting_buffer.unwrap(); + + let mut offsets_iter: Box> = Box::new(starting_buffer.iter().copied()); + + // 4. Get the last value in the starting buffer as the starting point for the next buffer + // Safety: We already filtered out empty lists + let mut advance_by = *starting_buffer.last().unwrap(); + + // 5. Iterate over the remaining buffers + for offset_buffer in offset_buffers_iterator { + // 6. Get the last value of the current buffer so we can know how much to advance the next buffer + // Safety: We already filtered out empty lists + let last_value = *offset_buffer.last().unwrap(); + + // 7. Advance the offset buffer by the last value in the previous buffer + let offset_buffer_iter = offset_buffer + .iter() + // Skip the first value as it is the initial offset of 0 + .skip(1) + .map(move |&x| x + advance_by); + + // 8. concat the current buffer with the previous buffer + // Chaining keeps the iterator have trusting length + offsets_iter = Box::new(offsets_iter.chain(offset_buffer_iter)); + + // 9. Update the next advance_by + advance_by += last_value; + } + + unsafe { + Buffer::from_trusted_len_iter(offsets_iter) + } +} + macro_rules! dict_helper { ($t:ty, $arrays:expr) => { return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _) }; } +macro_rules! list_dict_helper { + ($t:ty, $o: ty, $arrays:expr) => { + return Ok(Arc::new(concat_list_of_dictionaries::<$o, $t>($arrays)?) as _) + }; +} + fn get_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities { match data_type { DataType::Utf8 => binary_capacity::(arrays), @@ -169,6 +308,21 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { _ => unreachable!("illegal dictionary key type {k}") }; } else { + if let DataType::List(field) = d { + if let DataType::Dictionary(k, _) = field.data_type() { + return downcast_integer! { + k.as_ref() => (list_dict_helper, i32, arrays), + _ => unreachable!("illegal dictionary key type {k}") + }; + } + } else if let DataType::LargeList(field) = d { + if let DataType::Dictionary(k, _) = field.data_type() { + return downcast_integer! { + k.as_ref() => (list_dict_helper, i64, arrays), + _ => unreachable!("illegal dictionary key type {k}") + }; + } + } let capacity = get_capacity(arrays, d); concat_fallback(arrays, capacity) } @@ -228,8 +382,9 @@ pub fn concat_batches<'a>( #[cfg(test)] mod tests { use super::*; - use arrow_array::builder::StringDictionaryBuilder; + use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder}; use arrow_schema::{Field, Schema}; + use std::fmt::Debug; #[test] fn test_concat_empty_vec() { @@ -851,4 +1006,170 @@ mod tests { assert_eq!(array.null_count(), 10); assert_eq!(array.logical_null_count(), 10); } + + #[test] + fn concat_dictionary_list_array_simple() { + let scalars = vec![ + create_single_row_list_of_dict(vec![Some("a")]), + create_single_row_list_of_dict(vec![Some("a")]), + create_single_row_list_of_dict(vec![Some("b")]), + ]; + + let arrays = scalars.iter().map(|a| a as &(dyn Array)).collect::>(); + let concat_res = concat(arrays.as_slice()).unwrap(); + + let expected_list = create_list_of_dict(vec![ + // Row 1 + Some(vec![Some("a")]), + Some(vec![Some("a")]), + Some(vec![Some("b")]), + ]); + + let list = concat_res.as_list::(); + + // Assert that the list is equal to the expected list + list.iter().zip(expected_list.iter()).for_each(|(a, b)| { + assert_eq!(a, b); + }); + + let dict = list + .values() + .as_dictionary::() + .downcast_dict::() + .unwrap(); + println!("{:?}", dict); + + assert_dictionary_has_unique_values::<_, StringArray>( + list.values().as_dictionary::(), + ); + } + + #[test] + fn concat_dictionary_list_array_with_multiple_rows() { + let scalars = vec![ + create_list_of_dict(vec![ + // Row 1 + Some(vec![Some("a"), Some("c")]), + // Row 2 + None, + // Row 3 + Some(vec![Some("f"), Some("g"), None]), + // Row 4 + Some(vec![Some("c"), Some("f")]), + ]), + create_list_of_dict(vec![ + // Row 1 + Some(vec![Some("a")]), + // Row 2 + Some(vec![]), + // Row 3 + Some(vec![None, Some("b")]), + // Row 4 + Some(vec![Some("d"), Some("e")]), + ]), + create_list_of_dict(vec![ + // Row 1 + Some(vec![Some("g")]), + // Row 2 + Some(vec![Some("h"), Some("i")]), + // Row 3 + Some(vec![Some("j"), Some("a")]), + // Row 4 + Some(vec![Some("d"), Some("e")]), + ]), + ]; + let arrays = scalars + .iter() + .map(|a| a as &(dyn Array)) + .collect::>(); + let concat_res = concat(arrays.as_slice()).unwrap(); + + let expected_list = create_list_of_dict(vec![ + // First list: + + // Row 1 + Some(vec![Some("a"), Some("c")]), + // Row 2 + None, + // Row 3 + Some(vec![Some("f"), Some("g"), None]), + // Row 4 + Some(vec![Some("c"), Some("f")]), + // Second list: + // Row 1 + Some(vec![Some("a")]), + // Row 2 + Some(vec![]), + // Row 3 + Some(vec![None, Some("b")]), + // Row 4 + Some(vec![Some("d"), Some("e")]), + // Third list: + + // Row 1 + Some(vec![Some("g")]), + // Row 2 + Some(vec![Some("h"), Some("i")]), + // Row 3 + Some(vec![Some("j"), Some("a")]), + // Row 4 + Some(vec![Some("d"), Some("e")]), + ]); + + let list = concat_res.as_list::(); + + // Assert that the list is equal to the expected list + list.iter().zip(expected_list.iter()).for_each(|(a, b)| { + assert_eq!(a, b); + }); + + // Assert that the + assert_dictionary_has_unique_values::<_, StringArray>( + list.values().as_dictionary::(), + ); + } + + fn create_single_row_list_of_dict(list_items: Vec>) -> GenericListArray { + let rows = list_items.into_iter().map(|row| Some(row)).collect(); + + create_list_of_dict(vec![rows]) + } + + fn create_list_of_dict(rows: Vec>>>) -> GenericListArray { + let mut builder = + GenericListBuilder::::new(StringDictionaryBuilder::::new()); + + for row in rows { + builder.append_option(row); + } + + builder.finish() + } + + // TODO - use already exists helper or make it use this one + fn assert_dictionary_has_unique_values<'a, K, V: 'static>( + array: &'a DictionaryArray, + ) where + K: ArrowDictionaryKeyType, + V: Sync + Send, + &'a V: ArrayAccessor + IntoIterator, + + <&'a V as ArrayAccessor>::Item: Default + Clone + PartialEq + Debug + Ord, + <&'a V as IntoIterator>::Item: Clone + PartialEq + Debug + Ord, + { + let dict = array.downcast_dict::().unwrap(); + let mut values = dict.values().clone().into_iter().collect::>(); + + // remove duplicates must be sorted first so we can compare + values.sort(); + + let mut unique_values = values.clone(); + + unique_values.dedup(); + + assert_eq!( + values, unique_values, + "There are duplicates in the value list (the value list here is sorted which is only for the assertion)" + ); + } }