Skip to content

Commit

Permalink
Add UnionArray::into_parts (#5585)
Browse files Browse the repository at this point in the history
* Add `UnionArray::into_parts`

* Return `UnionFields` and `UnionMode` instead of `DataType`

* Add `into_parts` test with custom type ids

* Change `into_parts` output to better match `try_new`

* Remove UnionMode

---------

Co-authored-by: Raphael Taylor-Davies <[email protected]>
  • Loading branch information
mbrobbel and tustvold authored Apr 9, 2024
1 parent 1b3d1a9 commit c203785
Showing 1 changed file with 166 additions and 0 deletions.
166 changes: 166 additions & 0 deletions arrow-array/src/array/union_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
/// Contains the `UnionArray` type.
///
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

/// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout)
Expand Down Expand Up @@ -319,6 +320,70 @@ impl UnionArray {
fields,
}
}

/// Deconstruct this array into its constituent parts
///
/// # Example
///
/// ```
/// # use arrow_array::array::UnionArray;
/// # use arrow_array::types::Int32Type;
/// # use arrow_array::builder::UnionBuilder;
/// # use arrow_buffer::ScalarBuffer;
/// # fn main() -> Result<(), arrow_schema::ArrowError> {
/// let mut builder = UnionBuilder::new_dense();
/// builder.append::<Int32Type>("a", 1).unwrap();
/// let union_array = builder.build()?;
///
/// // Deconstruct into parts
/// let (type_ids, offsets, field_type_ids, fields) = union_array.into_parts();
///
/// // Reconstruct from parts
/// let union_array = UnionArray::try_new(
/// &field_type_ids,
/// type_ids.into_inner(),
/// offsets.map(ScalarBuffer::into_inner),
/// fields,
/// );
/// # Ok(())
/// # }
/// ```
#[allow(clippy::type_complexity)]
pub fn into_parts(
self,
) -> (
ScalarBuffer<i8>,
Option<ScalarBuffer<i32>>,
Vec<i8>,
Vec<(Field, ArrayRef)>,
) {
let Self {
data_type,
type_ids,
offsets,
fields,
} = self;
match data_type {
DataType::Union(union_fields, _) => {
let union_fields = union_fields.iter().collect::<HashMap<_, _>>();
let (field_type_ids, fields) = fields
.into_iter()
.enumerate()
.flat_map(|(type_id, array_ref)| {
array_ref.map(|array_ref| {
let type_id = type_id as i8;
(
type_id,
((*Arc::clone(union_fields[&type_id])).clone(), array_ref),
)
})
})
.unzip();
(type_ids, offsets, field_type_ids, fields)
}
_ => unreachable!(),
}
}
}

impl From<ArrayData> for UnionArray {
Expand Down Expand Up @@ -505,6 +570,7 @@ impl std::fmt::Debug for UnionArray {
mod tests {
use super::*;

use crate::array::Int8Type;
use crate::builder::UnionBuilder;
use crate::cast::AsArray;
use crate::types::{Float32Type, Float64Type, Int32Type, Int64Type};
Expand Down Expand Up @@ -1201,4 +1267,104 @@ mod tests {
assert_eq!(v.len(), 1);
assert_eq!(v.as_string::<i32>().value(0), "baz");
}

#[test]
fn into_parts() {
let mut builder = UnionBuilder::new_dense();
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int8Type>("b", 2).unwrap();
builder.append::<Int32Type>("a", 3).unwrap();
let dense_union = builder.build().unwrap();

let field = [
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int8, false),
];
let (type_ids, offsets, field_type_ids, fields) = dense_union.into_parts();
assert_eq!(field_type_ids, [0, 1]);
assert_eq!(
field.to_vec(),
fields
.iter()
.cloned()
.map(|(field, _)| field)
.collect::<Vec<_>>()
);
assert_eq!(type_ids, [0, 1, 0]);
assert!(offsets.is_some());
assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);

let result = UnionArray::try_new(
&field_type_ids,
type_ids.into_inner(),
offsets.map(ScalarBuffer::into_inner),
fields,
);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 3);

let mut builder = UnionBuilder::new_sparse();
builder.append::<Int32Type>("a", 1).unwrap();
builder.append::<Int8Type>("b", 2).unwrap();
builder.append::<Int32Type>("a", 3).unwrap();
let sparse_union = builder.build().unwrap();

let (type_ids, offsets, field_type_ids, fields) = sparse_union.into_parts();
assert_eq!(type_ids, [0, 1, 0]);
assert!(offsets.is_none());

let result = UnionArray::try_new(
&field_type_ids,
type_ids.into_inner(),
offsets.map(ScalarBuffer::into_inner),
fields,
);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 3);
}

#[test]
fn into_parts_custom_type_ids() {
let mut set_field_type_ids: [i8; 3] = [8, 4, 9];
let data_type = DataType::Union(
UnionFields::new(
set_field_type_ids,
[
Field::new("strings", DataType::Utf8, false),
Field::new("integers", DataType::Int32, false),
Field::new("floats", DataType::Float64, false),
],
),
UnionMode::Dense,
);
let string_array = StringArray::from(vec!["foo", "bar", "baz"]);
let int_array = Int32Array::from(vec![5, 6, 4]);
let float_array = Float64Array::from(vec![10.0]);
let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);
let data = ArrayData::builder(data_type)
.len(7)
.buffers(vec![type_ids, value_offsets])
.child_data(vec![
string_array.into_data(),
int_array.into_data(),
float_array.into_data(),
])
.build()
.unwrap();
let array = UnionArray::from(data);

let (type_ids, offsets, field_type_ids, fields) = array.into_parts();
set_field_type_ids.sort();
assert_eq!(field_type_ids, set_field_type_ids);
let result = UnionArray::try_new(
&field_type_ids,
type_ids.into_inner(),
offsets.map(ScalarBuffer::into_inner),
fields,
);
assert!(result.is_ok());
let array = result.unwrap();
assert_eq!(array.len(), 7);
}
}

0 comments on commit c203785

Please sign in to comment.