Skip to content

Commit

Permalink
Change into_parts output to better match try_new
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrobbel committed Apr 8, 2024
1 parent 23e6019 commit 8eff10f
Showing 1 changed file with 52 additions and 44 deletions.
96 changes: 52 additions & 44 deletions arrow-array/src/array/union_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use arrow_schema::{ArrowError, DataType, Field, UnionFields, UnionMode};
/// Contains the `UnionArray` type.
///
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;

/// An array of [values of varying types](https://arrow.apache.org/docs/format/Columnar.html#union-layout)
Expand Down Expand Up @@ -325,25 +326,37 @@ impl UnionArray {
/// # Example
///
/// ```
/// # use arrow_array::array::UnionArray;
/// # use arrow_array::types::Int32Type;
/// # use arrow_array::builder::UnionBuilder;
/// # use arrow_buffer::ScalarBuffer;
/// # fn main() -> Result<(), arrow_schema::ArrowError> {
/// let mut builder = UnionBuilder::new_dense();
/// builder.append::<Int32Type>("a", 1).unwrap();
/// let union_array = builder.build()?;
/// let (union_fields, union_mode, type_ids, offsets, fields) = union_array.into_parts();
///
/// // Deconstruct into parts
/// let (union_mode, type_ids, offsets, field_type_ids, fields) = union_array.into_parts();
///
/// // Reconstruct from parts
/// let union_array = UnionArray::try_new(
/// &field_type_ids,
/// type_ids.into_inner(),
/// offsets.map(ScalarBuffer::into_inner),
/// fields,
/// );
/// # Ok(())
/// # }
/// ```
#[allow(clippy::type_complexity)]
pub fn into_parts(
self,
) -> (
UnionFields,
UnionMode,
ScalarBuffer<i8>,
Option<ScalarBuffer<i32>>,
Vec<Option<ArrayRef>>,
Vec<i8>,
Vec<(Field, ArrayRef)>,
) {
let Self {
data_type,
Expand All @@ -353,7 +366,21 @@ impl UnionArray {
} = self;
match data_type {
DataType::Union(union_fields, union_mode) => {
(union_fields, union_mode, type_ids, offsets, fields)
let union_fields = union_fields.iter().collect::<HashMap<_, _>>();
let (field_type_ids, fields) = fields
.into_iter()
.enumerate()
.flat_map(|(type_id, array_ref)| {
array_ref.map(|array_ref| {
let type_id = type_id as i8;
(
type_id,
((*Arc::clone(union_fields[&type_id])).clone(), array_ref),
)
})
})
.unzip();
(union_mode, type_ids, offsets, field_type_ids, fields)
}
_ => unreachable!(),
}
Expand Down Expand Up @@ -1254,27 +1281,26 @@ mod tests {
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int8, false),
];
let field_type_ids = [0, 1];
let (union_fields, union_mode, type_ids, offsets, fields) = dense_union.into_parts();
let (union_mode, type_ids, offsets, field_type_ids, fields) = dense_union.into_parts();
assert_eq!(union_mode, UnionMode::Dense);
assert_eq!(field_type_ids, [0, 1]);
assert_eq!(
union_fields,
UnionFields::new(field_type_ids, field.clone())
field.to_vec(),
fields
.iter()
.cloned()
.map(|(field, _)| field)
.collect::<Vec<_>>()
);
assert_eq!(union_mode, UnionMode::Dense);
assert_eq!(type_ids, [0, 1, 0]);
assert!(offsets.is_some());
assert_eq!(offsets.as_ref().unwrap(), &[0, 0, 1]);
assert_eq!(fields.len(), 2);

let result = UnionArray::try_new(
&[0, 1],
&field_type_ids,
type_ids.into_inner(),
offsets.map(ScalarBuffer::into_inner),
field
.clone()
.into_iter()
.zip(fields.into_iter().flatten())
.collect(),
fields,
);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 3);
Expand All @@ -1285,35 +1311,27 @@ mod tests {
builder.append::<Int32Type>("a", 3).unwrap();
let sparse_union = builder.build().unwrap();

let (union_fields, union_mode, type_ids, offsets, fields) = sparse_union.into_parts();
assert_eq!(
union_fields,
UnionFields::new(field_type_ids, field.clone())
);
let (union_mode, type_ids, offsets, field_type_ids, fields) = sparse_union.into_parts();
assert_eq!(union_mode, UnionMode::Sparse);
assert_eq!(type_ids, [0, 1, 0]);
assert!(offsets.is_none());
assert_eq!(fields.len(), 2);

let result = UnionArray::try_new(
&[0, 1],
&field_type_ids,
type_ids.into_inner(),
offsets.map(ScalarBuffer::into_inner),
field
.into_iter()
.zip(fields.into_iter().flatten())
.collect(),
fields,
);
assert!(result.is_ok());
assert_eq!(result.unwrap().len(), 3);
}

#[test]
fn into_parts_custom_type_ids() {
const TYPE_IDS: [i8; 3] = [8, 4, 9];
let mut set_field_type_ids: [i8; 3] = [8, 4, 9];
let data_type = DataType::Union(
UnionFields::new(
TYPE_IDS,
set_field_type_ids,
[
Field::new("strings", DataType::Utf8, false),
Field::new("integers", DataType::Int32, false),
Expand All @@ -1339,28 +1357,18 @@ mod tests {
.unwrap();
let array = UnionArray::from(data);

let (union_fields, union_mode, type_ids, offsets, mut fields) = array.into_parts();
let (union_mode, type_ids, offsets, field_type_ids, fields) = array.into_parts();
assert_eq!(union_mode, UnionMode::Dense);
set_field_type_ids.sort();
assert_eq!(field_type_ids, set_field_type_ids);
let result = UnionArray::try_new(
&TYPE_IDS,
&field_type_ids,
type_ids.into_inner(),
offsets.map(ScalarBuffer::into_inner),
union_fields
.iter()
.map(|(type_id, field)| {
(
(*Arc::clone(field)).clone(),
fields[type_id as usize].take().unwrap(),
)
})
.collect(),
fields,
);
assert!(result.is_ok());
let array = result.unwrap();
assert_eq!(array.len(), 7);
let (_, _, _, _, fields) = array.into_parts();
for type_id in TYPE_IDS {
assert!(fields.get(type_id as usize).is_some_and(Option::is_some))
}
}
}

0 comments on commit 8eff10f

Please sign in to comment.