feat(arrow-select): make list of dictionary merge dictionary keys
TODO:
- [ ] Add support for nested lists
- [ ] Add more tests
- [ ] Fix failing test
rluvaton committed Dec 17, 2024
1 parent fcf4aa4 commit 34ab0ab
325 changes: 323 additions & 2 deletions arrow-select/src/concat.rs
@@ -34,9 +34,11 @@ use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values}
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
-use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer};
+use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, Buffer, NullBuffer, OffsetBuffer};
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, DataType, SchemaRef};
use num::Saturating;
use std::sync::Arc;

fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
@@ -129,12 +131,149 @@ fn concat_dictionaries<K: ArrowDictionaryKeyType>(
Ok(Arc::new(array))
}

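/// Concat [`GenericListArray`]s whose value type is [`DictionaryArray`],
/// merging the inner dictionaries so the resulting dictionary keeps
/// deduplicated values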
fn concat_list_of_dictionaries<OffsetSize: OffsetSizeTrait, K: ArrowDictionaryKeyType>(
arrays: &[&dyn Array],
) -> Result<ArrayRef, ArrowError> {
// Track both the total number of list rows and the total number of
// dictionary values across all inputs; the two differ and must not be mixed up
let mut list_output_len = 0;
let mut value_output_len = 0;

let lists = arrays
.iter()
.map(|x| x.as_list::<OffsetSize>())
.inspect(|l| list_output_len += l.len())
.collect::<Vec<_>>();

let dictionaries: Vec<_> = lists
.iter()
.map(|x| x.values().as_ref().as_dictionary::<K>())
.inspect(|d| value_output_len += d.len())
.collect();

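// Merging has a cost, so fall back to plain concatenation when the heuristic
// says it is unlikely to pay off (for example, when the concatenated
// dictionaries would fit the key space without overflowing)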
if !should_merge_dictionary_values::<K>(&dictionaries, value_output_len) {
return concat_fallback(arrays, Capacities::Array(list_output_len));
}

let merged = merge_dictionary_values(&dictionaries, None)?;

let lists_nulls = lists
.iter()
.fold(None, |acc, a| NullBuffer::union(acc.as_ref(), a.nulls()));

// Recompute the keys against the merged dictionary values
let mut key_values = Vec::with_capacity(value_output_len);

let mut dictionary_has_nulls = false;
for (d, mapping) in dictionaries.iter().zip(merged.key_mappings) {
dictionary_has_nulls |= d.null_count() != 0;
for key in d.keys().values() {
// Use `get` to handle nulls: a null position may hold an arbitrary key value,
// and the fallback key of 0 is harmless because the null buffer masks it
key_values.push(mapping.get(key.as_usize()).copied().unwrap_or_default())
}
}

let dictionary_nulls = dictionary_has_nulls.then(|| {
let mut nulls = BooleanBufferBuilder::new(value_output_len);
for d in &dictionaries {
match d.nulls() {
Some(n) => nulls.append_buffer(n.inner()),
None => nulls.append_n(d.len(), true),
}
}
NullBuffer::new(nulls.finish())
});
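// The validity bits are appended dictionary-by-dictionary, in the same order
// in which `key_values` was filled above, so keys and nulls stay aligned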

let keys = PrimitiveArray::<K>::new(key_values.into(), dictionary_nulls);
// Sanity check
assert_eq!(keys.len(), value_output_len);

let array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) };

// Merge value offsets from the lists
let all_value_offsets_iterator = lists
.iter()
.map(|x| x.offsets());

let value_offset_buffer = merge_value_offsets(all_value_offsets_iterator);

let builder = ArrayDataBuilder::new(arrays[0].data_type().clone())
// The list length is the number of rows, not the number of dictionary values
.len(list_output_len)
.nulls(lists_nulls)
// `GenericListArray` has exactly one buffer (the value offsets)
.buffers(vec![value_offset_buffer])
// `GenericListArray` has exactly one child (the merged dictionary values)
.child_data(vec![array.to_data()]);

// TODO - maybe use build_unchecked?
let array_data = builder.build()?;

let array = GenericListArray::<OffsetSize>::from(array_data);
Ok(Arc::new(array))
}

/// Merge value offsets
///
/// If we have the following offset buffers:
/// `[[0, 3, 5], [0, 2, 2, 8], [], [0, 0, 1]]`
/// the output should be:
/// `[0, 3, 5, 7, 7, 13, 13, 14]`
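///
/// Every buffer after the first non-empty one is shifted by the running sum of
/// the previous buffers' last offsets, and its leading 0 is dropped: above, the
/// second buffer becomes `7, 7, 13` (shifted by 5), the empty buffer is
/// skipped, and the last becomes `13, 14` (shifted by 13).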
fn merge_value_offsets<'a, OffsetSize: OffsetSizeTrait, I>(offset_buffers_iterator: I) -> Buffer
where
I: Iterator<Item = &'a OffsetBuffer<OffsetSize>>,
{
// 1. Filter out empty offset buffers
let mut offset_buffers_iterator = offset_buffers_iterator.filter(|x| !x.is_empty());

// 2. Get the first non-empty buffer as the starting point
let starting_buffer = offset_buffers_iterator.next();

// 3. If every buffer is empty, return an empty buffer
if starting_buffer.is_none() {
return Buffer::from(&[]);
}

let starting_buffer = starting_buffer.unwrap();

let mut offsets_iter: Box<dyn Iterator<Item = OffsetSize>> = Box::new(starting_buffer.iter().copied());

// 4. Get the last value in the starting buffer as the starting point for the next buffer
// Safety: We already filtered out empty lists
let mut advance_by = *starting_buffer.last().unwrap();

// 5. Iterate over the remaining buffers
for offset_buffer in offset_buffers_iterator {
// 6. Get the last value of the current buffer so we know how much further
// to shift the buffers that follow
// Safety: We already filtered out empty buffers
let last_value = *offset_buffer.last().unwrap();

// 7. Shift the current buffer by the accumulated length of all previous buffers
let offset_buffer_iter = offset_buffer
.iter()
// Skip the first value as it is the initial offset of 0
.skip(1)
.map(move |&x| x + advance_by);

// 8. Chain the shifted buffer onto the output iterator
// Chaining preserves an exact size hint, which `from_trusted_len_iter` relies on
offsets_iter = Box::new(offsets_iter.chain(offset_buffer_iter));

// 9. Update the next advance_by
advance_by += last_value;
}

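// SAFETY: the chained iterators are built from slice iterators plus
// `skip`/`map` adapters, all of which report an exact size hint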
unsafe {
Buffer::from_trusted_len_iter(offsets_iter)
}
}

macro_rules! dict_helper {
($t:ty, $arrays:expr) => {
return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _)
};
}

macro_rules! list_dict_helper {
($t:ty, $o: ty, $arrays:expr) => {
return Ok(Arc::new(concat_list_of_dictionaries::<$o, $t>($arrays)?) as _)
};
}

fn get_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capacities {
match data_type {
DataType::Utf8 => binary_capacity::<Utf8Type>(arrays),
@@ -169,6 +308,21 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
_ => unreachable!("illegal dictionary key type {k}")
};
} else {
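// Lists of dictionaries are special-cased so the inner dictionary values
// can be merged, rather than naively concatenated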
if let DataType::List(field) = d {
if let DataType::Dictionary(k, _) = field.data_type() {
return downcast_integer! {
k.as_ref() => (list_dict_helper, i32, arrays),
_ => unreachable!("illegal dictionary key type {k}")
};
}
} else if let DataType::LargeList(field) = d {
if let DataType::Dictionary(k, _) = field.data_type() {
return downcast_integer! {
k.as_ref() => (list_dict_helper, i64, arrays),
_ => unreachable!("illegal dictionary key type {k}")
};
}
}
let capacity = get_capacity(arrays, d);
concat_fallback(arrays, capacity)
}
@@ -228,8 +382,9 @@ pub fn concat_batches<'a>(
#[cfg(test)]
mod tests {
use super::*;
-use arrow_array::builder::StringDictionaryBuilder;
+use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder};
use arrow_schema::{Field, Schema};
use std::fmt::Debug;

#[test]
fn test_concat_empty_vec() {
@@ -851,4 +1006,170 @@ mod tests {
assert_eq!(array.null_count(), 10);
assert_eq!(array.logical_null_count(), 10);
}

#[test]
fn concat_dictionary_list_array_simple() {
let scalars = vec![
create_single_row_list_of_dict(vec![Some("a")]),
create_single_row_list_of_dict(vec![Some("a")]),
create_single_row_list_of_dict(vec![Some("b")]),
];

let arrays = scalars.iter().map(|a| a as &(dyn Array)).collect::<Vec<_>>();
let concat_res = concat(arrays.as_slice()).unwrap();

let expected_list = create_list_of_dict(vec![
// One row from each input array
Some(vec![Some("a")]),
Some(vec![Some("a")]),
Some(vec![Some("b")]),
]);

let list = concat_res.as_list::<i32>();

// Assert that the resulting list equals the expected list
assert_eq!(list.len(), expected_list.len());
list.iter().zip(expected_list.iter()).for_each(|(a, b)| {
assert_eq!(a, b);
});

// Check that the concatenated values are still a string dictionary
list.values()
.as_dictionary::<Int32Type>()
.downcast_dict::<StringArray>()
.unwrap();

assert_dictionary_has_unique_values::<_, StringArray>(
list.values().as_dictionary::<Int32Type>(),
);
}

#[test]
fn concat_dictionary_list_array_with_multiple_rows() {
let scalars = vec![
create_list_of_dict(vec![
// Row 1
Some(vec![Some("a"), Some("c")]),
// Row 2
None,
// Row 3
Some(vec![Some("f"), Some("g"), None]),
// Row 4
Some(vec![Some("c"), Some("f")]),
]),
create_list_of_dict(vec![
// Row 1
Some(vec![Some("a")]),
// Row 2
Some(vec![]),
// Row 3
Some(vec![None, Some("b")]),
// Row 4
Some(vec![Some("d"), Some("e")]),
]),
create_list_of_dict(vec![
// Row 1
Some(vec![Some("g")]),
// Row 2
Some(vec![Some("h"), Some("i")]),
// Row 3
Some(vec![Some("j"), Some("a")]),
// Row 4
Some(vec![Some("d"), Some("e")]),
]),
];
let arrays = scalars
.iter()
.map(|a| a as &(dyn Array))
.collect::<Vec<_>>();
let concat_res = concat(arrays.as_slice()).unwrap();

let expected_list = create_list_of_dict(vec![
// First list:
// Row 1
Some(vec![Some("a"), Some("c")]),
// Row 2
None,
// Row 3
Some(vec![Some("f"), Some("g"), None]),
// Row 4
Some(vec![Some("c"), Some("f")]),
// Second list:
// Row 1
Some(vec![Some("a")]),
// Row 2
Some(vec![]),
// Row 3
Some(vec![None, Some("b")]),
// Row 4
Some(vec![Some("d"), Some("e")]),
// Third list:
// Row 1
Some(vec![Some("g")]),
// Row 2
Some(vec![Some("h"), Some("i")]),
// Row 3
Some(vec![Some("j"), Some("a")]),
// Row 4
Some(vec![Some("d"), Some("e")]),
]);

let list = concat_res.as_list::<i32>();

// Assert that the resulting list equals the expected list
assert_eq!(list.len(), expected_list.len());
list.iter().zip(expected_list.iter()).for_each(|(a, b)| {
assert_eq!(a, b);
});

// Assert that the merged dictionary contains no duplicate values
assert_dictionary_has_unique_values::<_, StringArray>(
list.values().as_dictionary::<Int32Type>(),
);
}

fn create_single_row_list_of_dict(
list_items: Vec<Option<&'static str>>,
) -> GenericListArray<i32> {
create_list_of_dict(vec![Some(list_items)])
}

fn create_list_of_dict(rows: Vec<Option<Vec<Option<&'static str>>>>) -> GenericListArray<i32> {
let mut builder =
GenericListBuilder::<i32, _>::new(StringDictionaryBuilder::<Int32Type>::new());

for row in rows {
builder.append_option(row);
}

builder.finish()
}

// TODO - reuse an existing assertion helper if one exists, or migrate other tests to this one
fn assert_dictionary_has_unique_values<'a, K, V: 'static>(
array: &'a DictionaryArray<K>,
) where
K: ArrowDictionaryKeyType,
V: Sync + Send,
&'a V: ArrayAccessor + IntoIterator,
<&'a V as ArrayAccessor>::Item: Default + Clone + PartialEq + Debug + Ord,
<&'a V as IntoIterator>::Item: Clone + PartialEq + Debug + Ord,
{
let dict = array.downcast_dict::<V>().unwrap();
let mut values = dict.values().clone().into_iter().collect::<Vec<_>>();

// `dedup` only removes adjacent duplicates, so sort first
values.sort();

let mut unique_values = values.clone();

unique_values.dedup();

assert_eq!(
values, unique_values,
"There are duplicates in the value list (the value list here is sorted which is only for the assertion)"
);
}
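
// A minimal sketch of a test for `merge_value_offsets`, driving it with the
// buffers from its doc comment. `new_unchecked` is used only so the
// deliberately empty buffer can be built; offset buffers taken from real list
// arrays always contain at least one value.
#[test]
fn merge_value_offsets_doc_example() {
let buffers = unsafe {
[
OffsetBuffer::<i32>::new_unchecked(vec![0, 3, 5].into()),
OffsetBuffer::new_unchecked(vec![0, 2, 2, 8].into()),
OffsetBuffer::new_unchecked(vec![].into()),
OffsetBuffer::new_unchecked(vec![0, 0, 1].into()),
]
};

let merged = merge_value_offsets(buffers.iter());

assert_eq!(merged.typed_data::<i32>(), &[0, 3, 5, 7, 7, 13, 13, 14]);
}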
}
