From 237292896a69f734d1c98e5e8b45eac9a37b49c2 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Mon, 18 Sep 2023 21:27:05 +0100 Subject: [PATCH] Fix merge_dictionary_values in selection kernels --- arrow-select/src/dictionary.rs | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/arrow-select/src/dictionary.rs b/arrow-select/src/dictionary.rs index 8630b332f068..330196ae33f4 100644 --- a/arrow-select/src/dictionary.rs +++ b/arrow-select/src/dictionary.rs @@ -152,7 +152,7 @@ pub fn merge_dictionary_values( ) -> Result, ArrowError> { let mut num_values = 0; - let mut values = Vec::with_capacity(dictionaries.len()); + let mut values_arrays = Vec::with_capacity(dictionaries.len()); let mut value_slices = Vec::with_capacity(dictionaries.len()); for (idx, dictionary) in dictionaries.iter().enumerate() { @@ -164,11 +164,13 @@ pub fn merge_dictionary_values( (None, None) => None, }; let keys = dictionary.keys().values(); - let values_mask = compute_values_mask(keys, key_mask.as_ref()); - let v = dictionary.values().as_ref(); - num_values += v.len(); - value_slices.push(get_masked_values(v, &values_mask)); - values.push(v) + let values = dictionary.values().as_ref(); + let values_mask = compute_values_mask(keys, key_mask.as_ref(), values.len()); + + let masked_values = get_masked_values(values, &values_mask); + num_values += masked_values.len(); + value_slices.push(masked_values); + values_arrays.push(values) } // Map from value to new index @@ -202,7 +204,7 @@ pub fn merge_dictionary_values( Ok(MergedDictionaries { key_mappings, - values: interleave(&values, &indices)?, + values: interleave(&values_arrays, &indices)?, }) } @@ -211,9 +213,10 @@ pub fn merge_dictionary_values( fn compute_values_mask( keys: &ScalarBuffer, mask: Option<&BooleanBuffer>, + max_key: usize, ) -> BooleanBuffer { - let mut builder = BooleanBufferBuilder::new(keys.len()); - builder.advance(keys.len()); + let mut builder = BooleanBufferBuilder::new(max_key); + builder.advance(max_key); match mask { Some(n) => n @@ -330,4 +333,15 @@ mod tests { assert_eq!(&merged.key_mappings[0], &[0, 0, 0, 1, 0]); assert_eq!(&merged.key_mappings[1], &[]); } + + #[test] + fn test_merge_keys_smaller() { + let values = StringArray::from_iter_values(["a", "b"]); + let keys = Int32Array::from_iter_values([1]); + let a = DictionaryArray::new(keys, Arc::new(values)); + + let merged = merge_dictionary_values(&[&a], None).unwrap(); + let expected = StringArray::from(vec!["b"]); + assert_eq!(merged.values.as_ref(), &expected); + } }