From 924b6e9d0e62ad8cb85419268d8765611a72631e Mon Sep 17 00:00:00 2001 From: Jeffrey <22608443+Jefffrey@users.noreply.github.com> Date: Tue, 14 Nov 2023 08:01:10 +1100 Subject: [PATCH] IPC writer truncated sliced list/map values (#5071) * IPC writer truncated sliced list/map values * Add empty list test * Revert submodule update --- arrow-ipc/src/writer.rs | 429 ++++++++++++++++++++++++++-------------- 1 file changed, 285 insertions(+), 144 deletions(-) diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs index a58cbfc51428..1f6bf5f6fa85 100644 --- a/arrow-ipc/src/writer.rs +++ b/arrow-ipc/src/writer.rs @@ -1139,6 +1139,29 @@ fn get_buffer_element_width(spec: &BufferSpec) -> usize { } } +/// Common functionality for re-encoding offsets. Returns the new offsets as well as +/// original start offset and length for use in slicing child data. +fn reencode_offsets( + offsets: &Buffer, + data: &ArrayData, +) -> (Buffer, usize, usize) { + let offsets_slice: &[O] = offsets.typed_data::(); + let offset_slice = &offsets_slice[data.offset()..data.offset() + data.len() + 1]; + + let start_offset = offset_slice.first().unwrap(); + let end_offset = offset_slice.last().unwrap(); + + let offsets = match start_offset.as_usize() { + 0 => offsets.clone(), + _ => offset_slice.iter().map(|x| *x - *start_offset).collect(), + }; + + let start_offset = start_offset.as_usize(); + let end_offset = end_offset.as_usize(); + + (offsets, start_offset, end_offset - start_offset) +} + /// Returns the values and offsets [`Buffer`] for a ByteArray with offset type `O` /// /// In particular, this handles re-encoding the offsets if they don't start at `0`, @@ -1149,23 +1172,24 @@ fn get_byte_array_buffers(data: &ArrayData) -> (Buffer, Buff return (MutableBuffer::new(0).into(), MutableBuffer::new(0).into()); } - let buffers = data.buffers(); - let offsets: &[O] = buffers[0].typed_data::(); - let offset_slice = &offsets[data.offset()..data.offset() + data.len() + 1]; - - let start_offset = offset_slice.first().unwrap(); - let end_offset = offset_slice.last().unwrap(); + let (offsets, original_start_offset, len) = reencode_offsets::(&data.buffers()[0], data); + let values = data.buffers()[1].slice_with_length(original_start_offset, len); + (offsets, values) +} - let offsets = match start_offset.as_usize() { - 0 => buffers[0].clone(), - _ => offset_slice.iter().map(|x| *x - *start_offset).collect(), - }; +/// Similar logic as [`get_byte_array_buffers()`] but slices the child array instead +/// of a values buffer. +fn get_list_array_buffers(data: &ArrayData) -> (Buffer, ArrayData) { + if data.is_empty() { + return ( + MutableBuffer::new(0).into(), + data.child_data()[0].slice(0, 0), + ); + } - let values = buffers[1].slice_with_length( - start_offset.as_usize(), - end_offset.as_usize() - start_offset.as_usize(), - ); - (offsets, values) + let (offsets, original_start_offset, len) = reencode_offsets::(&data.buffers()[0], data); + let child_data = data.child_data()[0].slice(original_start_offset, len); + (offsets, child_data) } /// Write array data to a vector of bytes @@ -1250,20 +1274,14 @@ fn write_array_data( let byte_width = get_buffer_element_width(spec); let min_length = array_data.len() * byte_width; - if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) { + let buffer_slice = if buffer_need_truncate(array_data.offset(), buffer, spec, min_length) { let byte_offset = array_data.offset() * byte_width; let buffer_length = min(min_length, buffer.len() - byte_offset); - let buffer_slice = &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)]; - offset = write_buffer(buffer_slice, buffers, arrow_data, offset, compression_codec)?; + &buffer.as_slice()[byte_offset..(byte_offset + buffer_length)] } else { - offset = write_buffer( - buffer.as_slice(), - buffers, - arrow_data, - offset, - compression_codec, - )?; - } + buffer.as_slice() + }; + offset = write_buffer(buffer_slice, buffers, arrow_data, offset, compression_codec)?; } else if matches!(data_type, DataType::Boolean) { // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes). // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around. @@ -1272,6 +1290,39 @@ fn write_array_data( let buffer = &array_data.buffers()[0]; let buffer = buffer.bit_slice(array_data.offset(), array_data.len()); offset = write_buffer(&buffer, buffers, arrow_data, offset, compression_codec)?; + } else if matches!( + data_type, + DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _) + ) { + assert_eq!(array_data.buffers().len(), 1); + assert_eq!(array_data.child_data().len(), 1); + + // Truncate offsets and the child data to avoid writing unnecessary data + let (offsets, sliced_child_data) = match data_type { + DataType::List(_) => get_list_array_buffers::(array_data), + DataType::Map(_, _) => get_list_array_buffers::(array_data), + DataType::LargeList(_) => get_list_array_buffers::(array_data), + _ => unreachable!(), + }; + offset = write_buffer( + offsets.as_slice(), + buffers, + arrow_data, + offset, + compression_codec, + )?; + offset = write_array_data( + &sliced_child_data, + buffers, + arrow_data, + nodes, + offset, + sliced_child_data.len(), + sliced_child_data.null_count(), + compression_codec, + write_options, + )?; + return Ok(offset); } else { for buffer in array_data.buffers() { offset = write_buffer(buffer, buffers, arrow_data, offset, compression_codec)?; @@ -1372,8 +1423,10 @@ mod tests { use std::io::Seek; use std::sync::Arc; + use arrow_array::builder::GenericListBuilder; + use arrow_array::builder::MapBuilder; use arrow_array::builder::UnionBuilder; - use arrow_array::builder::{ListBuilder, PrimitiveRunBuilder, UInt32Builder}; + use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder}; use arrow_array::types::*; use arrow_schema::DataType; @@ -1382,6 +1435,30 @@ mod tests { use super::*; + fn serialize_file(rb: &RecordBatch) -> Vec { + let mut writer = FileWriter::try_new(vec![], &rb.schema()).unwrap(); + writer.write(rb).unwrap(); + writer.finish().unwrap(); + writer.into_inner().unwrap() + } + + fn deserialize_file(bytes: Vec) -> RecordBatch { + let mut reader = FileReader::try_new(Cursor::new(bytes), None).unwrap(); + reader.next().unwrap().unwrap() + } + + fn serialize_stream(record: &RecordBatch) -> Vec { + let mut stream_writer = StreamWriter::try_new(vec![], &record.schema()).unwrap(); + stream_writer.write(record).unwrap(); + stream_writer.finish().unwrap(); + stream_writer.into_inner().unwrap() + } + + fn deserialize_stream(bytes: Vec) -> RecordBatch { + let mut stream_reader = StreamReader::try_new(Cursor::new(bytes), None).unwrap(); + stream_reader.next().unwrap().unwrap() + } + #[test] #[cfg(feature = "lz4")] fn test_write_empty_record_batch_lz4_compression() { @@ -1407,27 +1484,18 @@ mod tests { file.rewind().unwrap(); { // read file - let mut reader = FileReader::try_new(file, None).unwrap(); - loop { - match reader.next() { - Some(Ok(read_batch)) => { - read_batch - .columns() - .iter() - .zip(record_batch.columns()) - .for_each(|(a, b)| { - assert_eq!(a.data_type(), b.data_type()); - assert_eq!(a.len(), b.len()); - assert_eq!(a.null_count(), b.null_count()); - }); - } - Some(Err(e)) => { - panic!("{}", e); - } - None => { - break; - } - } + let reader = FileReader::try_new(file, None).unwrap(); + for read_batch in reader { + read_batch + .unwrap() + .columns() + .iter() + .zip(record_batch.columns()) + .for_each(|(a, b)| { + assert_eq!(a.data_type(), b.data_type()); + assert_eq!(a.len(), b.len()); + assert_eq!(a.null_count(), b.null_count()); + }); } } } @@ -1456,27 +1524,18 @@ mod tests { file.rewind().unwrap(); { // read file - let mut reader = FileReader::try_new(file, None).unwrap(); - loop { - match reader.next() { - Some(Ok(read_batch)) => { - read_batch - .columns() - .iter() - .zip(record_batch.columns()) - .for_each(|(a, b)| { - assert_eq!(a.data_type(), b.data_type()); - assert_eq!(a.len(), b.len()); - assert_eq!(a.null_count(), b.null_count()); - }); - } - Some(Err(e)) => { - panic!("{}", e); - } - None => { - break; - } - } + let reader = FileReader::try_new(file, None).unwrap(); + for read_batch in reader { + read_batch + .unwrap() + .columns() + .iter() + .zip(record_batch.columns()) + .for_each(|(a, b)| { + assert_eq!(a.data_type(), b.data_type()); + assert_eq!(a.len(), b.len()); + assert_eq!(a.null_count(), b.null_count()); + }); } } } @@ -1504,27 +1563,18 @@ mod tests { file.rewind().unwrap(); { // read file - let mut reader = FileReader::try_new(file, None).unwrap(); - loop { - match reader.next() { - Some(Ok(read_batch)) => { - read_batch - .columns() - .iter() - .zip(record_batch.columns()) - .for_each(|(a, b)| { - assert_eq!(a.data_type(), b.data_type()); - assert_eq!(a.len(), b.len()); - assert_eq!(a.null_count(), b.null_count()); - }); - } - Some(Err(e)) => { - panic!("{}", e); - } - None => { - break; - } - } + let reader = FileReader::try_new(file, None).unwrap(); + for read_batch in reader { + read_batch + .unwrap() + .columns() + .iter() + .zip(record_batch.columns()) + .for_each(|(a, b)| { + assert_eq!(a.data_type(), b.data_type()); + assert_eq!(a.len(), b.len()); + assert_eq!(a.null_count(), b.null_count()); + }); } } } @@ -1754,20 +1804,6 @@ mod tests { write_union_file(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap()); } - fn serialize(record: &RecordBatch) -> Vec { - let buffer: Vec = Vec::new(); - let mut stream_writer = StreamWriter::try_new(buffer, &record.schema()).unwrap(); - stream_writer.write(record).unwrap(); - stream_writer.finish().unwrap(); - stream_writer.into_inner().unwrap() - } - - fn deserialize(bytes: Vec) -> RecordBatch { - let mut stream_reader = - crate::reader::StreamReader::try_new(std::io::Cursor::new(bytes), None).unwrap(); - stream_reader.next().unwrap().unwrap() - } - #[test] fn truncate_ipc_record_batch() { fn create_batch(rows: usize) -> RecordBatch { @@ -1789,14 +1825,16 @@ mod tests { let offset = 2; let record_batch_slice = big_record_batch.slice(offset, length); - assert!(serialize(&big_record_batch).len() > serialize(&small_record_batch).len()); + assert!( + serialize_stream(&big_record_batch).len() > serialize_stream(&small_record_batch).len() + ); assert_eq!( - serialize(&small_record_batch).len(), - serialize(&record_batch_slice).len() + serialize_stream(&small_record_batch).len(), + serialize_stream(&record_batch_slice).len() ); assert_eq!( - deserialize(serialize(&record_batch_slice)), + deserialize_stream(serialize_stream(&record_batch_slice)), record_batch_slice ); } @@ -1817,9 +1855,11 @@ mod tests { let record_batch = create_batch(); let record_batch_slice = record_batch.slice(1, 2); - let deserialized_batch = deserialize(serialize(&record_batch_slice)); + let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice)); - assert!(serialize(&record_batch).len() > serialize(&record_batch_slice).len()); + assert!( + serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len() + ); assert!(deserialized_batch.column(0).is_null(0)); assert!(deserialized_batch.column(0).is_valid(1)); @@ -1846,9 +1886,11 @@ mod tests { let record_batch = create_batch(); let record_batch_slice = record_batch.slice(1, 2); - let deserialized_batch = deserialize(serialize(&record_batch_slice)); + let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice)); - assert!(serialize(&record_batch).len() > serialize(&record_batch_slice).len()); + assert!( + serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len() + ); assert!(deserialized_batch.column(0).is_valid(0)); assert!(deserialized_batch.column(0).is_null(1)); @@ -1886,9 +1928,11 @@ mod tests { let record_batch = create_batch(); let record_batch_slice = record_batch.slice(1, 2); - let deserialized_batch = deserialize(serialize(&record_batch_slice)); + let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice)); - assert!(serialize(&record_batch).len() > serialize(&record_batch_slice).len()); + assert!( + serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len() + ); let structs = deserialized_batch .column(0) @@ -1913,9 +1957,11 @@ mod tests { let record_batch = create_batch(); let record_batch_slice = record_batch.slice(0, 1); - let deserialized_batch = deserialize(serialize(&record_batch_slice)); + let deserialized_batch = deserialize_stream(serialize_stream(&record_batch_slice)); - assert!(serialize(&record_batch).len() > serialize(&record_batch_slice).len()); + assert!( + serialize_stream(&record_batch).len() > serialize_stream(&record_batch_slice).len() + ); assert_eq!(record_batch_slice, deserialized_batch); } @@ -1996,13 +2042,8 @@ mod tests { let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(bools)]).unwrap(); let batch = batch.slice(offset, length); - let mut writer = StreamWriter::try_new(Vec::::new(), &schema).unwrap(); - writer.write(&batch).unwrap(); - writer.finish().unwrap(); - let data = writer.into_inner().unwrap(); - - let mut reader = StreamReader::try_new(Cursor::new(data), None).unwrap(); - let batch2 = reader.next().unwrap().unwrap(); + let data = serialize_stream(&batch); + let batch2 = deserialize_stream(data); assert_eq!(batch, batch2); } @@ -2060,37 +2101,137 @@ mod tests { } } + fn generate_list_data() -> GenericListArray { + let mut ls = GenericListBuilder::::new(UInt32Builder::new()); + + for i in 0..100_000 { + for value in [i, i, i] { + ls.values().append_value(value); + } + ls.append(true) + } + + ls.finish() + } + + fn generate_nested_list_data() -> GenericListArray { + let mut ls = + GenericListBuilder::::new(GenericListBuilder::::new(UInt32Builder::new())); + + for _i in 0..10_000 { + for j in 0..10 { + for value in [j, j, j, j] { + ls.values().values().append_value(value); + } + ls.values().append(true) + } + ls.append(true); + } + + ls.finish() + } + + fn generate_map_array_data() -> MapArray { + let keys_builder = UInt32Builder::new(); + let values_builder = UInt32Builder::new(); + + let mut builder = MapBuilder::new(None, keys_builder, values_builder); + + for i in 0..100_000 { + for _j in 0..3 { + builder.keys().append_value(i); + builder.values().append_value(i * 2); + } + builder.append(true).unwrap(); + } + + builder.finish() + } + + /// Ensure when serde full & sliced versions they are equal to original input. + /// Also ensure serialized sliced version is significantly smaller than serialized full. + fn roundtrip_ensure_sliced_smaller(in_batch: RecordBatch, expected_size_factor: usize) { + // test both full and sliced versions + let in_sliced = in_batch.slice(999, 1); + + let bytes_batch = serialize_file(&in_batch); + let bytes_sliced = serialize_file(&in_sliced); + + // serializing 1 row should be significantly smaller than serializing 100,000 + assert!(bytes_sliced.len() < (bytes_batch.len() / expected_size_factor)); + + // ensure both are still valid and equal to originals + let out_batch = deserialize_file(bytes_batch); + assert_eq!(in_batch, out_batch); + + let out_sliced = deserialize_file(bytes_sliced); + assert_eq!(in_sliced, out_sliced); + } + #[test] fn encode_lists() { let val_inner = Field::new("item", DataType::UInt32, true); - let val_list_field = Field::new_list("val", val_inner, false); + let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false); + let schema = Arc::new(Schema::new(vec![val_list_field])); + + let values = Arc::new(generate_list_data::()); + + let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); + roundtrip_ensure_sliced_smaller(in_batch, 1000); + } + + #[test] + fn encode_empty_list() { + let val_inner = Field::new("item", DataType::UInt32, true); + let val_list_field = Field::new("val", DataType::List(Arc::new(val_inner)), false); + let schema = Arc::new(Schema::new(vec![val_list_field])); + + let values = Arc::new(generate_list_data::()); + let in_batch = RecordBatch::try_new(schema, vec![values]) + .unwrap() + .slice(999, 0); + let out_batch = deserialize_file(serialize_file(&in_batch)); + assert_eq!(in_batch, out_batch); + } + + #[test] + fn encode_large_lists() { + let val_inner = Field::new("item", DataType::UInt32, true); + let val_list_field = Field::new("val", DataType::LargeList(Arc::new(val_inner)), false); let schema = Arc::new(Schema::new(vec![val_list_field])); - let values = { - let u32 = UInt32Builder::new(); - let mut ls = ListBuilder::new(u32); + let values = Arc::new(generate_list_data::()); - for list in [vec![1u32, 2, 3], vec![4, 5, 6], vec![7, 8, 9, 10]] { - for value in list { - ls.values().append_value(value); - } - ls.append(true) - } + // ensure when serde full & sliced versions they are equal to original input + // also ensure serialized sliced version is significantly smaller than serialized full + let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); + roundtrip_ensure_sliced_smaller(in_batch, 1000); + } - ls.finish() - }; + #[test] + fn encode_nested_lists() { + let inner_int = Arc::new(Field::new("item", DataType::UInt32, true)); + let inner_list_field = Arc::new(Field::new("item", DataType::List(inner_int), true)); + let list_field = Field::new("val", DataType::List(inner_list_field), true); + let schema = Arc::new(Schema::new(vec![list_field])); - let batch = RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(values)]).unwrap(); - let batch = batch.slice(1, 1); + let values = Arc::new(generate_nested_list_data::()); - let mut writer = FileWriter::try_new(Vec::::new(), &schema).unwrap(); - writer.write(&batch).unwrap(); - writer.finish().unwrap(); - let data = writer.into_inner().unwrap(); + let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); + roundtrip_ensure_sliced_smaller(in_batch, 1000); + } - let mut reader = FileReader::try_new(Cursor::new(data), None).unwrap(); - let batch2 = reader.next().unwrap().unwrap(); - assert_eq!(batch, batch2); + #[test] + fn encode_map_array() { + let keys = Arc::new(Field::new("keys", DataType::UInt32, false)); + let values = Arc::new(Field::new("values", DataType::UInt32, true)); + let map_field = Field::new_map("map", "entries", keys, values, false, true); + let schema = Arc::new(Schema::new(vec![map_field])); + + let values = Arc::new(generate_map_array_data()); + + let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap(); + roundtrip_ensure_sliced_smaller(in_batch, 1000); } }