From c203785ca398b879960bffbd30b988c9728b7c23 Mon Sep 17 00:00:00 2001 From: Matthijs Brobbel Date: Tue, 9 Apr 2024 11:42:44 +0200 Subject: [PATCH] Add `UnionArray::into_parts` (#5585) * Add `UnionArray::into_parts` * Return `UnionFields` and `UnionMode` instead of `DataType` * Add `into_parts` test with custom type ids * Change `into_parts` output to better match `try_new` * Remove UnionMode --------- Co-authored-by: Raphael Taylor-Davies --- arrow-array/src/array/union_array.rs | 166 +++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs index e3e637247537..22d4cf90a092 100644 --- a/arrow-array/src/array/union_array.rs +++ b/arrow-array/src/array/union_array.rs @@ -23,6 +23,7 @@ use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode}; /// Contains the `UnionArray` type. /// use std::any::Any; +use std::collections::HashMap; use std::sync::Arc; /// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout) @@ -319,6 +320,70 @@ impl UnionArray { fields, } } + + /// Deconstruct this array into its constituent parts + /// + /// # Example + /// + /// ``` + /// # use arrow_array::array::UnionArray; + /// # use arrow_array::types::Int32Type; + /// # use arrow_array::builder::UnionBuilder; + /// # use arrow_buffer::ScalarBuffer; + /// # fn main() -> Result<(), arrow_schema::ArrowError> { + /// let mut builder = UnionBuilder::new_dense(); + /// builder.append::("a", 1).unwrap(); + /// let union_array = builder.build()?; + /// + /// // Deconstruct into parts + /// let (type_ids, offsets, field_type_ids, fields) = union_array.into_parts(); + /// + /// // Reconstruct from parts + /// let union_array = UnionArray::try_new( + /// &field_type_ids, + /// type_ids.into_inner(), + /// offsets.map(ScalarBuffer::into_inner), + /// fields, + /// ); + /// # Ok(()) + /// # } + /// ``` + #[allow(clippy::type_complexity)] + pub fn into_parts( + self, + ) -> ( + ScalarBuffer, + Option>, + Vec, + Vec<(Field, ArrayRef)>, + ) { + let Self { + data_type, + type_ids, + offsets, + fields, + } = self; + match data_type { + DataType::Union(union_fields, _) => { + let union_fields = union_fields.iter().collect::>(); + let (field_type_ids, fields) = fields + .into_iter() + .enumerate() + .flat_map(|(type_id, array_ref)| { + array_ref.map(|array_ref| { + let type_id = type_id as i8; + ( + type_id, + ((*Arc::clone(union_fields[&type_id])).clone(), array_ref), + ) + }) + }) + .unzip(); + (type_ids, offsets, field_type_ids, fields) + } + _ => unreachable!(), + } + } } impl From for UnionArray { @@ -505,6 +570,7 @@ impl std::fmt::Debug for UnionArray { mod tests { use super::*; + use crate::array::Int8Type; use crate::builder::UnionBuilder; use crate::cast::AsArray; use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type}; @@ -1201,4 +1267,104 @@ mod tests { assert_eq!(v.len(), 1); assert_eq!(v.as_string::().value(0), "baz"); } + + #[test] + fn into_parts() { + let mut builder = UnionBuilder::new_dense(); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("a", 3).unwrap(); + let dense_union = builder.build().unwrap(); + + let field = [ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int8, false), + ]; + let (type_ids, offsets, field_type_ids, fields) = dense_union.into_parts(); + assert_eq!(field_type_ids, [0, 1]); + assert_eq!( + field.to_vec(), + fields + .iter() + .cloned() + .map(|(field, _)| field) + .collect::>() + ); + assert_eq!(type_ids, [0, 1, 0]); + assert!(offsets.is_some()); + assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]); + + let result = UnionArray::try_new( + &field_type_ids, + type_ids.into_inner(), + offsets.map(ScalarBuffer::into_inner), + fields, + ); + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 3); + + let mut builder = UnionBuilder::new_sparse(); + builder.append::("a", 1).unwrap(); + builder.append::("b", 2).unwrap(); + builder.append::("a", 3).unwrap(); + let sparse_union = builder.build().unwrap(); + + let (type_ids, offsets, field_type_ids, fields) = sparse_union.into_parts(); + assert_eq!(type_ids, [0, 1, 0]); + assert!(offsets.is_none()); + + let result = UnionArray::try_new( + &field_type_ids, + type_ids.into_inner(), + offsets.map(ScalarBuffer::into_inner), + fields, + ); + assert!(result.is_ok()); + assert_eq!(result.unwrap().len(), 3); + } + + #[test] + fn into_parts_custom_type_ids() { + let mut set_field_type_ids: [i8; 3] = [8, 4, 9]; + let data_type = DataType::Union( + UnionFields::new( + set_field_type_ids, + [ + Field::new("strings", DataType::Utf8, false), + Field::new("integers", DataType::Int32, false), + Field::new("floats", DataType::Float64, false), + ], + ), + UnionMode::Dense, + ); + let string_array = StringArray::from(vec!["foo", "bar", "baz"]); + let int_array = Int32Array::from(vec![5, 6, 4]); + let float_array = Float64Array::from(vec![10.0]); + let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]); + let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]); + let data = ArrayData::builder(data_type) + .len(7) + .buffers(vec![type_ids, value_offsets]) + .child_data(vec![ + string_array.into_data(), + int_array.into_data(), + float_array.into_data(), + ]) + .build() + .unwrap(); + let array = UnionArray::from(data); + + let (type_ids, offsets, field_type_ids, fields) = array.into_parts(); + set_field_type_ids.sort(); + assert_eq!(field_type_ids, set_field_type_ids); + let result = UnionArray::try_new( + &field_type_ids, + type_ids.into_inner(), + offsets.map(ScalarBuffer::into_inner), + fields, + ); + assert!(result.is_ok()); + let array = result.unwrap(); + assert_eq!(array.len(), 7); + } }