diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 43a707aa197..2ae060ebc93 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -135,31 +135,43 @@ fn concat_list_of_dictionaries Result { let mut output_len = 0; + let mut list_has_nulls = false; + let lists = arrays .iter() .map(|x| x.as_list::()) + .inspect(|l| { + output_len += l.len(); + list_has_nulls |= l.null_count() != 0; + }) .collect::>(); + let mut dictionary_output_len = 0; let dictionaries: Vec<_> = lists .iter() .map(|x| x.values().as_ref().as_dictionary::()) - // TODO? - .inspect(|d| output_len += d.len()) + .inspect(|d| dictionary_output_len += d.len()) .collect(); - if !should_merge_dictionary_values::(&dictionaries, output_len) { + if !should_merge_dictionary_values::(&dictionaries, dictionary_output_len) { return concat_fallback(arrays, Capacities::Array(output_len)); } let merged = merge_dictionary_values(&dictionaries, None)?; - let lists_nulls = lists - .iter() - .fold(None, |acc, a| NullBuffer::union(acc.as_ref(), a.nulls())); - + let lists_nulls = list_has_nulls.then(|| { + let mut nulls = BooleanBufferBuilder::new(output_len); + for l in &lists { + match l.nulls() { + Some(n) => nulls.append_buffer(n.inner()), + None => nulls.append_n(l.len(), true), + } + } + NullBuffer::new(nulls.finish()) + }); // Recompute keys - let mut key_values = Vec::with_capacity(output_len); + let mut key_values = Vec::with_capacity(dictionary_output_len); let mut dictionary_has_nulls = false; for (d, mapping) in dictionaries.iter().zip(merged.key_mappings) { @@ -171,7 +183,7 @@ fn concat_list_of_dictionaries nulls.append_buffer(n.inner()), @@ -183,7 +195,7 @@ fn concat_list_of_dictionaries::new(key_values.into(), dictionary_nulls); // Sanity check - assert_eq!(keys.len(), output_len); + assert_eq!(keys.len(), dictionary_output_len); let array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) }; @@ -1146,7 +1158,6 @@ mod tests { builder.finish() } - // TODO - use already exists helper or make it use this one fn assert_dictionary_has_unique_values<'a, K, V: 'static>( array: &'a DictionaryArray, ) where