Skip to content

Commit

Permalink
Fix concat of lists of dictionaries
Browse files Browse the repository at this point in the history
  • Loading branch information
rluvaton committed Dec 17, 2024
1 parent 34ab0ab commit 66825fb
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,31 +135,43 @@ fn concat_list_of_dictionaries<OffsetSize: OffsetSizeTrait, K: ArrowDictionaryKe
arrays: &[&dyn Array],
) -> Result<ArrayRef, ArrowError> {
let mut output_len = 0;
let mut list_has_nulls = false;

let lists = arrays
.iter()
.map(|x| x.as_list::<OffsetSize>())
.inspect(|l| {
output_len += l.len();
list_has_nulls |= l.null_count() != 0;
})
.collect::<Vec<_>>();

let mut dictionary_output_len = 0;
let dictionaries: Vec<_> = lists
.iter()
.map(|x| x.values().as_ref().as_dictionary::<K>())
// TODO?
.inspect(|d| output_len += d.len())
.inspect(|d| dictionary_output_len += d.len())
.collect();

if !should_merge_dictionary_values::<K>(&dictionaries, output_len) {
if !should_merge_dictionary_values::<K>(&dictionaries, dictionary_output_len) {
return concat_fallback(arrays, Capacities::Array(output_len));
}

let merged = merge_dictionary_values(&dictionaries, None)?;

let lists_nulls = lists
.iter()
.fold(None, |acc, a| NullBuffer::union(acc.as_ref(), a.nulls()));

let lists_nulls = list_has_nulls.then(|| {
let mut nulls = BooleanBufferBuilder::new(output_len);
for l in &lists {
match l.nulls() {
Some(n) => nulls.append_buffer(n.inner()),
None => nulls.append_n(l.len(), true),
}
}
NullBuffer::new(nulls.finish())
});

// Recompute keys
let mut key_values = Vec::with_capacity(output_len);
let mut key_values = Vec::with_capacity(dictionary_output_len);

let mut dictionary_has_nulls = false;
for (d, mapping) in dictionaries.iter().zip(merged.key_mappings) {
Expand All @@ -171,7 +183,7 @@ fn concat_list_of_dictionaries<OffsetSize: OffsetSizeTrait, K: ArrowDictionaryKe
}

let dictionary_nulls = dictionary_has_nulls.then(|| {
let mut nulls = BooleanBufferBuilder::new(output_len);
let mut nulls = BooleanBufferBuilder::new(dictionary_output_len);
for d in &dictionaries {
match d.nulls() {
Some(n) => nulls.append_buffer(n.inner()),
Expand All @@ -183,7 +195,7 @@ fn concat_list_of_dictionaries<OffsetSize: OffsetSizeTrait, K: ArrowDictionaryKe

let keys = PrimitiveArray::<K>::new(key_values.into(), dictionary_nulls);
// Sanity check
assert_eq!(keys.len(), output_len);
assert_eq!(keys.len(), dictionary_output_len);

let array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) };

Expand Down Expand Up @@ -1146,7 +1158,6 @@ mod tests {
builder.finish()
}

// TODO: use the already existing helper, or make that helper use this one
fn assert_dictionary_has_unique_values<'a, K, V: 'static>(
array: &'a DictionaryArray<K>,
) where
Expand Down

0 comments on commit 66825fb

Please sign in to comment.