Skip to content

Commit

Permalink
Change UnionArray constructors (#5623)
Browse files Browse the repository at this point in the history
* Change `UnionArray` constructors

* Fix a comment

* Clippy and avoid using hashmaps

* Additional test

---------

Co-authored-by: Raphael Taylor-Davies <[email protected]>
  • Loading branch information
mbrobbel and tustvold authored May 8, 2024
1 parent 4045fb5 commit b25c441
Show file tree
Hide file tree
Showing 10 changed files with 360 additions and 332 deletions.
372 changes: 209 additions & 163 deletions arrow-array/src/array/union_array.rs

Large diffs are not rendered by default.

78 changes: 39 additions & 39 deletions arrow-array/src/builder/union_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use arrow_buffer::{ArrowNativeType, Buffer};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, DataType, Field};
use std::any::Any;
use std::collections::HashMap;
use std::collections::BTreeMap;
use std::sync::Arc;

/// `FieldData` is a helper struct to track the state of the fields in the `UnionBuilder`.
#[derive(Debug)]
Expand Down Expand Up @@ -142,7 +143,7 @@ pub struct UnionBuilder {
/// The current number of slots in the array
len: usize,
/// Maps field names to `FieldData` instances which track the builders for that field
fields: HashMap<String, FieldData>,
fields: BTreeMap<String, FieldData>,
/// Builder to keep track of type ids
type_id_builder: Int8BufferBuilder,
/// Builder to keep track of offsets (`None` for sparse unions)
Expand All @@ -165,7 +166,7 @@ impl UnionBuilder {
pub fn with_capacity_dense(capacity: usize) -> Self {
Self {
len: 0,
fields: HashMap::default(),
fields: Default::default(),
type_id_builder: Int8BufferBuilder::new(capacity),
value_offset_builder: Some(Int32BufferBuilder::new(capacity)),
initial_capacity: capacity,
Expand All @@ -176,7 +177,7 @@ impl UnionBuilder {
pub fn with_capacity_sparse(capacity: usize) -> Self {
Self {
len: 0,
fields: HashMap::default(),
fields: Default::default(),
type_id_builder: Int8BufferBuilder::new(capacity),
value_offset_builder: None,
initial_capacity: capacity,
Expand Down Expand Up @@ -274,40 +275,39 @@ impl UnionBuilder {
}

/// Builds this builder creating a new `UnionArray`.
pub fn build(mut self) -> Result<UnionArray, ArrowError> {
let type_id_buffer = self.type_id_builder.finish();
let value_offsets_buffer = self.value_offset_builder.map(|mut b| b.finish());
let mut children = Vec::new();
for (
name,
FieldData {
type_id,
data_type,
mut values_buffer,
slots,
null_buffer_builder: mut bitmap_builder,
},
) in self.fields.into_iter()
{
let buffer = values_buffer.finish();
let arr_data_builder = ArrayDataBuilder::new(data_type.clone())
.add_buffer(buffer)
.len(slots)
.nulls(bitmap_builder.finish());

let arr_data_ref = unsafe { arr_data_builder.build_unchecked() };
let array_ref = make_array(arr_data_ref);
children.push((type_id, (Field::new(name, data_type, false), array_ref)))
}

children.sort_by(|a, b| {
a.0.partial_cmp(&b.0)
.expect("This will never be None as type ids are always i8 values.")
});
let children: Vec<_> = children.into_iter().map(|(_, b)| b).collect();

let type_ids: Vec<i8> = (0_i8..children.len() as i8).collect();

UnionArray::try_new(&type_ids, type_id_buffer, value_offsets_buffer, children)
pub fn build(self) -> Result<UnionArray, ArrowError> {
let mut children = Vec::with_capacity(self.fields.len());
let union_fields = self
.fields
.into_iter()
.map(
|(
name,
FieldData {
type_id,
data_type,
mut values_buffer,
slots,
mut null_buffer_builder,
},
)| {
let array_ref = make_array(unsafe {
ArrayDataBuilder::new(data_type.clone())
.add_buffer(values_buffer.finish())
.len(slots)
.nulls(null_buffer_builder.finish())
.build_unchecked()
});
children.push(array_ref);
(type_id, Arc::new(Field::new(name, data_type, false)))
},
)
.collect();
UnionArray::try_new(
union_fields,
self.type_id_builder.into(),
self.value_offset_builder.map(Into::into),
children,
)
}
}
18 changes: 11 additions & 7 deletions arrow-cast/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ mod tests {
use arrow_array::builder::*;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::Buffer;
use arrow_buffer::ScalarBuffer;
use arrow_schema::*;

use crate::display::array_value_to_string;
Expand Down Expand Up @@ -851,14 +851,18 @@ mod tests {

// Can't use UnionBuilder with non-primitive types, so manually build outer UnionArray
let a_array = Int32Array::from(vec![None, None, None, Some(1234), Some(23)]);
let type_ids = Buffer::from_slice_ref([1_i8, 1, 0, 0, 1]);
let type_ids = [1, 1, 0, 0, 1].into_iter().collect::<ScalarBuffer<i8>>();

let children: Vec<(Field, Arc<dyn Array>)> = vec![
(Field::new("a", DataType::Int32, true), Arc::new(a_array)),
(inner_field.clone(), Arc::new(inner)),
];
let children = vec![Arc::new(a_array) as Arc<dyn Array>, Arc::new(inner)];

let union_fields = [
(0, Arc::new(Field::new("a", DataType::Int32, true))),
(1, Arc::new(inner_field.clone())),
]
.into_iter()
.collect();

let outer = UnionArray::try_new(&[0, 1], type_ids, None, children).unwrap();
let outer = UnionArray::try_new(union_fields, type_ids, None, children).unwrap();

let schema = Schema::new(vec![Field::new_union(
"Teamsters",
Expand Down
101 changes: 44 additions & 57 deletions arrow-flight/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -597,20 +597,17 @@ fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef
(DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => {
let union_arr = array.as_any().downcast_ref::<UnionArray>().unwrap();

let (type_ids, fields): (Vec<i8>, Vec<&FieldRef>) = fields.iter().unzip();

Arc::new(UnionArray::try_new(
&type_ids,
union_arr.type_ids().inner().clone(),
fields.clone(),
union_arr.type_ids().clone(),
None,
fields
.iter()
.enumerate()
.map(|(col, field)| {
Ok((
field.as_ref().clone(),
arrow_cast::cast(union_arr.child(col as i8), field.data_type())?,
))
.map(|(type_id, field)| {
Ok(arrow_cast::cast(
union_arr.child(type_id),
field.data_type(),
)?)
})
.collect::<Result<Vec<_>>>()?,
)?)
Expand All @@ -625,10 +622,10 @@ mod tests {
use arrow_array::builder::StringDictionaryBuilder;
use arrow_array::*;
use arrow_array::{cast::downcast_array, types::*};
use arrow_buffer::Buffer;
use arrow_buffer::ScalarBuffer;
use arrow_cast::pretty::pretty_format_batches;
use arrow_ipc::MetadataVersion;
use arrow_schema::UnionMode;
use arrow_schema::{UnionFields, UnionMode};
use std::collections::HashMap;

use crate::decode::{DecodedPayload, FlightDataDecoder};
Expand Down Expand Up @@ -849,16 +846,23 @@ mod tests {
true,
)];

let type_ids = vec![0, 1, 2];
let union_fields = vec![
Field::new_list(
"dict_list",
Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
true,
let union_fields = [
(
0,
Arc::new(Field::new_list(
"dict_list",
Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
true,
)),
),
Field::new_struct("struct", struct_fields.clone(), true),
Field::new("string", DataType::Utf8, true),
];
(
1,
Arc::new(Field::new_struct("struct", struct_fields.clone(), true)),
),
(2, Arc::new(Field::new("string", DataType::Utf8, true))),
]
.into_iter()
.collect::<UnionFields>();

let struct_fields = vec![Field::new_list(
"dict_list",
Expand All @@ -872,21 +876,15 @@ mod tests {

let arr1 = builder.finish();

let type_id_buffer = Buffer::from_slice_ref([0_i8]);
let type_id_buffer = [0].into_iter().collect::<ScalarBuffer<i8>>();
let arr1 = UnionArray::try_new(
&type_ids,
union_fields.clone(),
type_id_buffer,
None,
vec![
(union_fields[0].clone(), Arc::new(arr1)),
(
union_fields[1].clone(),
new_null_array(union_fields[1].data_type(), 1),
),
(
union_fields[2].clone(),
new_null_array(union_fields[2].data_type(), 1),
),
Arc::new(arr1) as Arc<dyn Array>,
new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
],
)
.unwrap();
Expand All @@ -896,47 +894,36 @@ mod tests {
let arr2 = Arc::new(builder.finish());
let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);

let type_id_buffer = Buffer::from_slice_ref([1_i8]);
let type_id_buffer = [1].into_iter().collect::<ScalarBuffer<i8>>();
let arr2 = UnionArray::try_new(
&type_ids,
union_fields.clone(),
type_id_buffer,
None,
vec![
(
union_fields[0].clone(),
new_null_array(union_fields[0].data_type(), 1),
),
(union_fields[1].clone(), Arc::new(arr2)),
(
union_fields[2].clone(),
new_null_array(union_fields[2].data_type(), 1),
),
new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
Arc::new(arr2),
new_null_array(union_fields.iter().nth(2).unwrap().1.data_type(), 1),
],
)
.unwrap();

let type_id_buffer = Buffer::from_slice_ref([2_i8]);
let type_id_buffer = [2].into_iter().collect::<ScalarBuffer<i8>>();
let arr3 = UnionArray::try_new(
&type_ids,
union_fields.clone(),
type_id_buffer,
None,
vec![
(
union_fields[0].clone(),
new_null_array(union_fields[0].data_type(), 1),
),
(
union_fields[1].clone(),
new_null_array(union_fields[1].data_type(), 1),
),
(
union_fields[2].clone(),
Arc::new(StringArray::from(vec!["e"])),
),
new_null_array(union_fields.iter().next().unwrap().1.data_type(), 1),
new_null_array(union_fields.iter().nth(1).unwrap().1.data_type(), 1),
Arc::new(StringArray::from(vec!["e"])),
],
)
.unwrap();

let (type_ids, union_fields): (Vec<_>, Vec<_>) = union_fields
.iter()
.map(|(type_id, field_ref)| (type_id, (*Arc::clone(field_ref)).clone()))
.unzip();
let schema = Arc::new(Schema::new(vec![Field::new_union(
"union",
type_ids.clone(),
Expand Down
23 changes: 8 additions & 15 deletions arrow-integration-test/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
//!
//! This is not a canonical format, but provides a human-readable way of verifying language implementations
use arrow_buffer::ScalarBuffer;
use hex::decode;
use num::BigInt;
use num::Signed;
Expand Down Expand Up @@ -835,26 +836,18 @@ pub fn array_from_json(
));
};

let offset: Option<Buffer> = json_col.offset.map(|offsets| {
let offsets: Vec<i32> =
offsets.iter().map(|v| v.as_i64().unwrap() as i32).collect();
Buffer::from(&offsets.to_byte_slice())
});
let offset: Option<ScalarBuffer<i32>> = json_col
.offset
.map(|offsets| offsets.iter().map(|v| v.as_i64().unwrap() as i32).collect());

let mut children: Vec<(Field, Arc<dyn Array>)> = vec![];
let mut children = Vec::with_capacity(fields.len());
for ((_, field), col) in fields.iter().zip(json_col.children.unwrap()) {
let array = array_from_json(field, col, dictionaries)?;
children.push((field.as_ref().clone(), array));
children.push(array);
}

let field_type_ids = fields.iter().map(|(id, _)| id).collect::<Vec<_>>();
let array = UnionArray::try_new(
&field_type_ids,
Buffer::from(&type_ids.to_byte_slice()),
offset,
children,
)
.unwrap();
let array =
UnionArray::try_new(fields.clone(), type_ids.into(), offset, children).unwrap();
Ok(Arc::new(array))
}
t => Err(ArrowError::JsonError(format!(
Expand Down
17 changes: 8 additions & 9 deletions arrow-ipc/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use std::io::{BufReader, Read, Seek, SeekFrom};
use std::sync::Arc;

use arrow_array::*;
use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer, ScalarBuffer};
use arrow_data::ArrayData;
use arrow_schema::*;

Expand Down Expand Up @@ -214,26 +214,25 @@ fn create_array(
reader.next_buffer()?;
}

let type_ids: Buffer = reader.next_buffer()?[..len].into();
let type_ids: ScalarBuffer<i8> = reader.next_buffer()?.slice_with_length(0, len).into();

let value_offsets = match mode {
UnionMode::Dense => {
let buffer = reader.next_buffer()?;
Some(buffer[..len * 4].into())
let offsets: ScalarBuffer<i32> =
reader.next_buffer()?.slice_with_length(0, len * 4).into();
Some(offsets)
}
UnionMode::Sparse => None,
};

let mut children = Vec::with_capacity(fields.len());
let mut ids = Vec::with_capacity(fields.len());

for (id, field) in fields.iter() {
for (_id, field) in fields.iter() {
let child = create_array(reader, field, variadic_counts, require_alignment)?;
children.push((field.as_ref().clone(), child));
ids.push(id);
children.push(child);
}

let array = UnionArray::try_new(&ids, type_ids, value_offsets, children)?;
let array = UnionArray::try_new(fields.clone(), type_ids, value_offsets, children)?;
Ok(Arc::new(array))
}
Null => {
Expand Down
Loading

0 comments on commit b25c441

Please sign in to comment.