From d16cf2331a8f390491ff1928df68e9b56163a1c8 Mon Sep 17 00:00:00 2001 From: Xiangpeng Hao Date: Wed, 27 Mar 2024 16:32:03 +0000 Subject: [PATCH] Add tests and fix bugs with dict types --- arrow-ipc/src/reader.rs | 107 ++++++++++++++++++++++++++++++++++++++++ arrow-ipc/src/writer.rs | 64 +++++++++--------------- 2 files changed, 129 insertions(+), 42 deletions(-) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index fd7d6be3af44..b219821086a5 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -1801,6 +1801,113 @@ mod tests { assert_eq!(input_batch, output_batch); } + const LONG_TEST_STRING: &str = + "This is a long string to make sure binary view array handles it"; + + #[test] + fn test_roundtrip_view_types() { + let schema = Schema::new(vec![ + Field::new("field_1", DataType::BinaryView, true), + Field::new("field_2", DataType::Utf8, true), + Field::new("field_3", DataType::Utf8View, true), + ]); + let bin_values: Vec<Option<&[u8]>> = vec![ + Some(b"foo"), + Some(b"bar"), + Some(LONG_TEST_STRING.as_bytes()), + ]; + let utf8_values: Vec<Option<&str>> = vec![Some("foo"), Some("bar"), Some(LONG_TEST_STRING)]; + let bin_view_array = BinaryViewArray::from_iter(bin_values); + let utf8_array = StringArray::from_iter(utf8_values.iter()); + let utf8_view_array = StringViewArray::from_iter(utf8_values); + let record_batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(bin_view_array), + Arc::new(utf8_array), + Arc::new(utf8_view_array), + ], + ) + .unwrap(); + + assert_eq!(record_batch, roundtrip_ipc(&record_batch)); + assert_eq!(record_batch, roundtrip_ipc_stream(&record_batch)); + + let sliced_batch = record_batch.slice(1, 2); + assert_eq!(sliced_batch, roundtrip_ipc(&sliced_batch)); + assert_eq!(sliced_batch, roundtrip_ipc_stream(&sliced_batch)); + } + + #[test] + fn test_roundtrip_view_types_nested_dict() { + let bin_values: Vec<Option<&[u8]>> = vec![ + Some(b"foo"), + Some(b"bar"), + Some(LONG_TEST_STRING.as_bytes()), + Some(b"field"), + ]; + let
utf8_values: Vec<Option<&str>> = vec![ + Some("foo"), + Some("bar"), + Some(LONG_TEST_STRING), + Some("field"), + ]; + let bin_view_array = Arc::new(BinaryViewArray::from_iter(bin_values)); + let utf8_view_array = Arc::new(StringViewArray::from_iter(utf8_values)); + + let key_dict_keys = Int8Array::from_iter_values([0, 0, 1, 2, 0, 1, 3]); + let key_dict_array = DictionaryArray::new(key_dict_keys, utf8_view_array.clone()); + let keys_field = Arc::new(Field::new_dict( + "keys", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8View)), + true, + 1, + false, + )); + + let value_dict_keys = Int8Array::from_iter_values([0, 3, 0, 1, 2, 0, 1]); + let value_dict_array = DictionaryArray::new(value_dict_keys, bin_view_array); + let values_field = Arc::new(Field::new_dict( + "values", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::BinaryView)), + true, + 2, + false, + )); + let entry_struct = StructArray::from(vec![ + (keys_field, make_array(key_dict_array.into_data())), + (values_field, make_array(value_dict_array.into_data())), + ]); + + let map_data_type = DataType::Map( + Arc::new(Field::new( + "entries", + entry_struct.data_type().clone(), + false, + )), + false, + ); + let entry_offsets = Buffer::from_slice_ref([0, 2, 4, 7]); + let map_data = ArrayData::builder(map_data_type) + .len(3) + .add_buffer(entry_offsets) + .add_child_data(entry_struct.into_data()) + .build() + .unwrap(); + let map_array = MapArray::from(map_data); + + let dict_keys = Int8Array::from_iter_values([0, 1, 0, 1, 1, 2, 0, 1, 2]); + let dict_dict_array = DictionaryArray::new(dict_keys, Arc::new(map_array)); + let schema = Arc::new(Schema::new(vec![Field::new( + "f1", + dict_dict_array.data_type().clone(), + false, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(dict_dict_array)]).unwrap(); + assert_eq!(batch, roundtrip_ipc(&batch)); + assert_eq!(batch, roundtrip_ipc_stream(&batch)); + } + #[test] fn test_no_columns_batch() { let schema = 
Arc::new(Schema::empty()); diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index d570a198e9cb..11bf18122d65 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -29,10 +29,7 @@ use flatbuffers::FlatBufferBuilder; use arrow_array::builder::BufferBuilder; use arrow_array::cast::*; -use arrow_array::types::{ - Int16Type, Int32Type, Int64Type, Int8Type, RunEndIndexType, UInt16Type, UInt32Type, UInt64Type, - UInt8Type, -}; +use arrow_array::types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}; use arrow_array::*; use arrow_buffer::bit_util; use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; @@ -431,7 +428,7 @@ impl IpcDataGenerator { write_options, )?; - append_variadic_buffer_counts(&mut variadic_buffer_counts, array); + append_variadic_buffer_counts(&mut variadic_buffer_counts, &array_data); } // pad the tail of body data let len = arrow_data.len(); @@ -518,6 +515,9 @@ impl IpcDataGenerator { write_options, )?; + let mut variadic_buffer_counts = vec![]; + append_variadic_buffer_counts(&mut variadic_buffer_counts, array_data); + // pad the tail of body data let len = arrow_data.len(); let pad_len = pad_to_8(len as u32); @@ -526,6 +526,11 @@ impl IpcDataGenerator { // write data let buffers = fbb.create_vector(&buffers); let nodes = fbb.create_vector(&nodes); + let variadic_buffer = if variadic_buffer_counts.is_empty() { + None + } else { + Some(fbb.create_vector(&variadic_buffer_counts)) + }; let root = { let mut batch_builder = crate::RecordBatchBuilder::new(&mut fbb); @@ -535,6 +540,9 @@ impl IpcDataGenerator { if let Some(c) = compression { batch_builder.add_compression(c); } + if let Some(v) = variadic_buffer { + batch_builder.add_variadicBufferCounts(v); + } batch_builder.finish() }; @@ -564,50 +572,22 @@ impl IpcDataGenerator { } } -fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &dyn Array) { +fn append_variadic_buffer_counts(counts: &mut Vec<i64>, array: &ArrayData) { match array.data_type() { 
DataType::BinaryView | DataType::Utf8View => { // The spec documents the counts only includes the variadic buffers, not the view/null buffers. // https://arrow.apache.org/docs/format/Columnar.html#variadic-buffers - counts.push(array.to_data().buffers().len() as i64 - 1); - } - DataType::Struct(_) => { - let array = array.as_struct(); - for column in array.columns() { - append_variadic_buffer_counts(counts, column.as_ref()); - } + counts.push(array.buffers().len() as i64 - 1); } - DataType::LargeList(_) => { - let array: &LargeListArray = array.as_list(); - append_variadic_buffer_counts(counts, array.values()); + DataType::Dictionary(_, _) => { + // Dictionary types are handled in `encode_dictionaries`. + return; } - DataType::List(_) => { - let array: &ListArray = array.as_list(); - append_variadic_buffer_counts(counts, array.values()); - } - DataType::FixedSizeList(_, _) => { - let array = array.as_fixed_size_list(); - append_variadic_buffer_counts(counts, array.values()); - } - DataType::Dictionary(kt, _) => { - macro_rules! 
set_subarray_counts { - ($array:expr, $counts:expr, $type:ty, $variant:ident) => { - if &DataType::$variant == kt.as_ref() { - let array: &DictionaryArray<$type> = $array.as_dictionary(); - append_variadic_buffer_counts($counts, array.values()); - } - }; + _ => { + for child in array.child_data() { + append_variadic_buffer_counts(counts, child) } - set_subarray_counts!(array, counts, Int8Type, Int8); - set_subarray_counts!(array, counts, Int16Type, Int16); - set_subarray_counts!(array, counts, Int32Type, Int32); - set_subarray_counts!(array, counts, Int64Type, Int64); - set_subarray_counts!(array, counts, UInt8Type, UInt8); - set_subarray_counts!(array, counts, UInt16Type, UInt16); - set_subarray_counts!(array, counts, UInt32Type, UInt32); - set_subarray_counts!(array, counts, UInt64Type, UInt64); } - _ => {} } } @@ -1883,7 +1863,7 @@ mod tests { } #[test] - fn test_write_binary_view() { + fn test_write_view_types() { const LONG_TEST_STRING: &str = "This is a long string to make sure binary view array handles it"; let schema = Schema::new(vec![