From 8eff10f20cc7ed310528ab4c6e0f5ece69da9164 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Mon, 8 Apr 2024 13:08:36 +0200 Subject: [PATCH] Change `into_parts` output to better match `try_new` --- arrow-array/src/array/union_array.rs | 96 +++++++++++++++------------- 1 file changed, 52 insertions(+), 44 deletions(-) diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs index bda1ef67010f..0a85fa6c3917 100644 --- a/arrow-array/src/array/union_array.rs +++ b/arrow-array/src/array/union_array.rs @@ -23,6 +23,7 @@ use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode}; /// Contains the `UnionArray` type. /// use std::any::Any; +use std::collections::HashMap; use std::sync::Arc; /// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout) @@ -325,13 +326,25 @@ impl UnionArray { /// # Example /// /// ``` + /// # use arrow_array::array::UnionArray; /// # use arrow_array::types::Int32Type; /// # use arrow_array::builder::UnionBuilder; + /// # use arrow_buffer::ScalarBuffer; /// # fn main() -> Result<(), arrow_schema::ArrowError> { /// let mut builder = UnionBuilder::new_dense(); /// builder.append::("a", 1).unwrap(); /// let union_array = builder.build()?; - /// let (union_fields, union_mode, type_ids, offsets, fields) = union_array.into_parts(); + /// + /// // Deconstruct into parts + /// let (union_mode, type_ids, offsets, field_type_ids, fields) = union_array.into_parts(); + /// + /// // Reconstruct from parts + /// let union_array = UnionArray::try_new( + /// &field_type_ids, + /// type_ids.into_inner(), + /// offsets.map(ScalarBuffer::into_inner), + /// fields, + /// ); /// # Ok(()) /// # } /// ``` @@ -339,11 +352,11 @@ impl UnionArray { pub fn into_parts( self, ) -> ( - UnionFields, UnionMode, ScalarBuffer, Option>, - Vec>, + Vec, + Vec<(Field, ArrayRef)>, ) { let Self { data_type, @@ -353,7 +366,21 @@ impl UnionArray { } = self; match data_type { DataType::Union(union_fields, union_mode) => { - (union_fields, union_mode, type_ids, offsets, fields) + let union_fields = union_fields.iter().collect::>(); + let (field_type_ids, fields) = fields + .into_iter() + .enumerate() + .flat_map(|(type_id, array_ref)| { + array_ref.map(|array_ref| { + let type_id = type_id as i8; + ( + type_id, + ((*Arc::clone(union_fields[&type_id])).clone(), array_ref), + ) + }) + }) + .unzip(); + (union_mode, type_ids, offsets, field_type_ids, fields) } _ => unreachable!(), } @@ -1254,27 +1281,26 @@ mod tests { Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int8, false), ]; - let field_type_ids = [0, 1]; - let (union_fields, union_mode, type_ids, offsets, fields) = dense_union.into_parts(); + let (union_mode, type_ids, offsets, field_type_ids, fields) = dense_union.into_parts(); + assert_eq!(union_mode, UnionMode::Dense); + assert_eq!(field_type_ids, [0, 1]); assert_eq!( - union_fields, - UnionFields::new(field_type_ids, field.clone()) + field.to_vec(), + fields + .iter() + .cloned() + .map(|(field, _)| field) + .collect::>() ); - assert_eq!(union_mode, UnionMode::Dense); assert_eq!(type_ids, [0, 1, 0]); assert!(offsets.is_some()); assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]); - assert_eq!(fields.len(), 2); let result = UnionArray::try_new( - &[0, 1], + &field_type_ids, type_ids.into_inner(), offsets.map(ScalarBuffer::into_inner), - field - .clone() - .into_iter() - .zip(fields.into_iter().flatten()) - .collect(), + fields, ); assert!(result.is_ok()); assert_eq!(result.unwrap().len(), 3); @@ -1285,24 +1311,16 @@ mod tests { builder.append::("a", 3).unwrap(); let sparse_union = builder.build().unwrap(); - let (union_fields, union_mode, type_ids, offsets, fields) = sparse_union.into_parts(); - assert_eq!( - union_fields, - UnionFields::new(field_type_ids, field.clone()) - ); + let (union_mode, type_ids, offsets, field_type_ids, fields) = sparse_union.into_parts(); assert_eq!(union_mode, UnionMode::Sparse); assert_eq!(type_ids, [0, 1, 0]); assert!(offsets.is_none()); - assert_eq!(fields.len(), 2); let result = UnionArray::try_new( - &[0, 1], + &field_type_ids, type_ids.into_inner(), offsets.map(ScalarBuffer::into_inner), - field - .into_iter() - .zip(fields.into_iter().flatten()) - .collect(), + fields, ); assert!(result.is_ok()); assert_eq!(result.unwrap().len(), 3); @@ -1310,10 +1328,10 @@ mod tests { #[test] fn into_parts_custom_type_ids() { - const TYPE_IDS: [i8; 3] = [8, 4, 9]; + let mut set_field_type_ids: [i8; 3] = [8, 4, 9]; let data_type = DataType::Union( UnionFields::new( - TYPE_IDS, + set_field_type_ids, [ Field::new("strings", DataType::Utf8, false), Field::new("integers", DataType::Int32, false), @@ -1339,28 +1357,18 @@ mod tests { .unwrap(); let array = UnionArray::from(data); - let (union_fields, union_mode, type_ids, offsets, mut fields) = array.into_parts(); + let (union_mode, type_ids, offsets, field_type_ids, fields) = array.into_parts(); assert_eq!(union_mode, UnionMode::Dense); + set_field_type_ids.sort(); + assert_eq!(field_type_ids, set_field_type_ids); let result = UnionArray::try_new( - &TYPE_IDS, + &field_type_ids, type_ids.into_inner(), offsets.map(ScalarBuffer::into_inner), - union_fields - .iter() - .map(|(type_id, field)| { - ( - (*Arc::clone(field)).clone(), - fields[type_id as usize].take().unwrap(), - ) - }) - .collect(), + fields, ); assert!(result.is_ok()); let array = result.unwrap(); assert_eq!(array.len(), 7); - let (_, _, _, _, fields) = array.into_parts(); - for type_id in TYPE_IDS { - assert!(fields.get(type_id as usize).is_some_and(Option::is_some)) - } } }