Skip to content

Commit

Permalink
Add tests and fix bugs with dict types
Browse files Browse the repository at this point in the history
  • Loading branch information
XiangpengHao committed Mar 27, 2024
1 parent d121ce6 commit d16cf23
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 42 deletions.
107 changes: 107 additions & 0 deletions arrow-ipc/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1801,6 +1801,113 @@ mod tests {
assert_eq!(input_batch, output_batch);
}

// A string long enough that view arrays cannot store it inline, exercising
// the variadic data-buffer path of BinaryView/Utf8View during IPC round trips.
const LONG_TEST_STRING: &str =
    "This is a long string to make sure binary view array handles it";

#[test]
fn test_roundtrip_view_types() {
    // One column of each view-related flavour: BinaryView, plain Utf8, Utf8View.
    let schema = Schema::new(vec![
        Field::new("field_1", DataType::BinaryView, true),
        Field::new("field_2", DataType::Utf8, true),
        Field::new("field_3", DataType::Utf8View, true),
    ]);

    // LONG_TEST_STRING is too long to be inlined in a view, so the view
    // arrays must reference a separate variadic data buffer.
    let binary_input: Vec<Option<&[u8]>> = vec![
        Some(b"foo"),
        Some(b"bar"),
        Some(LONG_TEST_STRING.as_bytes()),
    ];
    let string_input: Vec<Option<&str>> =
        vec![Some("foo"), Some("bar"), Some(LONG_TEST_STRING)];

    let binary_view = BinaryViewArray::from_iter(binary_input);
    let utf8 = StringArray::from_iter(string_input.iter());
    let utf8_view = StringViewArray::from_iter(string_input);

    let record_batch = RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![Arc::new(binary_view), Arc::new(utf8), Arc::new(utf8_view)],
    )
    .unwrap();

    // The full batch must survive both the file and the stream IPC formats.
    assert_eq!(record_batch, roundtrip_ipc(&record_batch));
    assert_eq!(record_batch, roundtrip_ipc_stream(&record_batch));

    // A sliced batch (non-zero offset) must round-trip as well.
    let sliced_batch = record_batch.slice(1, 2);
    assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch));
    assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch));
}

#[test]
fn test_roundtrip_view_types_nested_dict() {
    // Round-trips view types nested as dictionary values inside a Map, itself
    // wrapped in an outer dictionary — exercising the dictionary-aware path of
    // variadic buffer counting in the IPC writer.
    // LONG_TEST_STRING forces both view arrays to use a variadic data buffer.
    let bin_values: Vec<Option<&[u8]>> = vec![
        Some(b"foo"),
        Some(b"bar"),
        Some(LONG_TEST_STRING.as_bytes()),
        Some(b"field"),
    ];
    let utf8_values: Vec<Option<&str>> = vec![
        Some("foo"),
        Some("bar"),
        Some(LONG_TEST_STRING),
        Some("field"),
    ];
    let bin_view_array = Arc::new(BinaryViewArray::from_iter(bin_values));
    let utf8_view_array = Arc::new(StringViewArray::from_iter(utf8_values));

    // Map "keys" column: Dictionary<Int8, Utf8View>, dictionary id 1.
    let key_dict_keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]);
    let key_dict_array = DictionaryArray::new(key_dict_keys, utf8_view_array.clone());
    let keys_field = Arc::new(Field::new_dict(
        "keys",
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8View)),
        true,
        1, // dictionary id — distinct from the values field below
        false,
    ));

    // Map "values" column: Dictionary<Int8, BinaryView>, dictionary id 2.
    let value_dict_keys = Int8Array::from_iter_values([0, 3, 0, 1, 2, 0, 1]);
    let value_dict_array = DictionaryArray::new(value_dict_keys, bin_view_array);
    let values_field = Arc::new(Field::new_dict(
        "values",
        DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::BinaryView)),
        true,
        2, // dictionary id — distinct from the keys field above
        false,
    ));
    // Map entries are a struct of (keys, values).
    let entry_struct = StructArray::from(vec![
        (keys_field, make_array(key_dict_array.into_data())),
        (values_field, make_array(value_dict_array.into_data())),
    ]);

    let map_data_type = DataType::Map(
        Arc::new(Field::new(
            "entries",
            entry_struct.data_type().clone(),
            false,
        )),
        false, // keys are not sorted
    );
    // Three map rows over the seven entries: [0,2), [2,4), [4,7).
    let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 7]);
    let map_data = ArrayData::builder(map_data_type)
        .len(3)
        .add_buffer(entry_offsets)
        .add_child_data(entry_struct.into_data())
        .build()
        .unwrap();
    let map_array = MapArray::from(map_data);

    // Outer layer: a dictionary whose values are the map rows themselves.
    let dict_keys = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]);
    let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array));
    let schema = Arc::new(Schema::new(vec![Field::new(
        "f1",
        dict_dict_array.data_type().clone(),
        false,
    )]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap();
    // Must survive both the file and the stream IPC formats.
    assert_eq!(batch, roundtrip_ipc(&batch));
    assert_eq!(batch, roundtrip_ipc_stream(&batch));
}

#[test]
fn test_no_columns_batch() {
let schema = Arc::new(Schema::empty());
Expand Down
64 changes: 22 additions & 42 deletions arrow-ipc/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ use flatbuffers::FlatBufferBuilder;

use arrow_array::builder::BufferBuilder;
use arrow_array::cast::*;
use arrow_array::types::{
Int16Type, Int32Type, Int64Type, Int8Type, RunEndIndexType, UInt16Type, UInt32Type, UInt64Type,
UInt8Type,
};
use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType};
use arrow_array::*;
use arrow_buffer::bit_util;
use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
Expand Down Expand Up @@ -431,7 +428,7 @@ impl IpcDataGenerator {
write_options,
)?;

append_variadic_buffer_counts(&mut variadic_buffer_counts, array);
append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data);
}
// pad the tail of body data
let len = arrow_data.len();
Expand Down Expand Up @@ -518,6 +515,9 @@ impl IpcDataGenerator {
write_options,
)?;

let mut variadic_buffer_counts = vec![];
append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data);

// pad the tail of body data
let len = arrow_data.len();
let pad_len = pad_to_8(len as u32);
Expand All @@ -526,6 +526,11 @@ impl IpcDataGenerator {
// write data
let buffers = fbb.create_vector(&buffers);
let nodes = fbb.create_vector(&nodes);
let variadic_buffer = if variadic_buffer_counts.is_empty() {
None
} else {
Some(fbb.create_vector(&variadic_buffer_counts))
};

let root = {
let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb);
Expand All @@ -535,6 +540,9 @@ impl IpcDataGenerator {
if let Some(c) = compression {
batch_builder.add_compression(c);
}
if let Some(v) = variadic_buffer {
batch_builder.add_variadicBufferCounts(v);
}
batch_builder.finish()
};

Expand Down Expand Up @@ -564,50 +572,22 @@ impl IpcDataGenerator {
}
}

fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) {
fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &ArrayData) {
match array.data_type() {
DataType::BinaryView | DataType::Utf8View => {
// The spec documents the counts only includes the variadic buffers, not the view/null buffers.
// https://arrow.apache.org/docs/format/Columnar.html#variadic-buffers
counts.push(array.to_data().buffers().len() as i64 - 1);
}
DataType::Struct(_) => {
let array = array.as_struct();
for column in array.columns() {
append_variadic_buffer_counts(counts, column.as_ref());
}
counts.push(array.buffers().len() as i64 - 1);
}
DataType::LargeList(_) => {
let array: &LargeListArray = array.as_list();
append_variadic_buffer_counts(counts, array.values());
DataType::Dictionary(_, _) => {
// Dictionary types are handled in `encode_dictionaries`.
return;
}
DataType::List(_) => {
let array: &ListArray = array.as_list();
append_variadic_buffer_counts(counts, array.values());
}
DataType::FixedSizeList(_, _) => {
let array = array.as_fixed_size_list();
append_variadic_buffer_counts(counts, array.values());
}
DataType::Dictionary(kt, _) => {
macro_rules! set_subarray_counts {
($array:expr, $counts:expr, $type:ty, $variant:ident) => {
if &DataType::$variant == kt.as_ref() {
let array: &DictionaryArray<$type> = $array.as_dictionary();
append_variadic_buffer_counts($counts, array.values());
}
};
_ => {
for child in array.child_data() {
append_variadic_buffer_counts(counts, child)
}
set_subarray_counts!(array, counts, Int8Type, Int8);
set_subarray_counts!(array, counts, Int16Type, Int16);
set_subarray_counts!(array, counts, Int32Type, Int32);
set_subarray_counts!(array, counts, Int64Type, Int64);
set_subarray_counts!(array, counts, UInt8Type, UInt8);
set_subarray_counts!(array, counts, UInt16Type, UInt16);
set_subarray_counts!(array, counts, UInt32Type, UInt32);
set_subarray_counts!(array, counts, UInt64Type, UInt64);
}
_ => {}
}
}

Expand Down Expand Up @@ -1883,7 +1863,7 @@ mod tests {
}

#[test]
fn test_write_binary_view() {
fn test_write_view_types() {
const LONG_TEST_STRING: &str =
"This is a long string to make sure binary view array handles it";
let schema = Schema::new(vec![
Expand Down

0 comments on commit d16cf23

Please sign in to comment.