From eddef43d1cb46c1287da187ea1d86b0e1dc35a13 Mon Sep 17 00:00:00 2001
From: Howard Zuo
Date: Thu, 4 Apr 2024 06:21:22 -0400
Subject: [PATCH] Optionally require alignment when reading IPC, respect
 alignment when writing (#5554)

* fix
* Remove redundant alignment check
* fix typo
* Add comment about randomized testing
* add doc comment on enforce_zero_copy
* address alamb feedback
* be explicit
* fix unit tests
* fix arrow-flight tests
* clippy
* hide api change + add require_alignment to StreamDecoder too
* Preserve docs

---------

Co-authored-by: Raphael Taylor-Davies
---
 arrow-buffer/src/buffer/scalar.rs |   2 +-
 arrow-flight/src/encode.rs        |   6 +-
 arrow-ipc/src/reader.rs           | 241 +++++++++++++++++++++++++-----
 arrow-ipc/src/reader/stream.rs    |  27 +++-
 arrow-ipc/src/writer.rs           | 231 ++++++++++++++++++++++++----
 5 files changed, 431 insertions(+), 76 deletions(-)

diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs
index 2019cc79830d..343b8549e93d 100644
--- a/arrow-buffer/src/buffer/scalar.rs
+++ b/arrow-buffer/src/buffer/scalar.rs
@@ -63,7 +63,7 @@ impl<T: ArrowNativeType> ScalarBuffer<T> {
     /// This method will panic if
     ///
     /// * `offset` or `len` would result in overflow
-    /// * `buffer` is not aligned to a multiple of `std::mem::size_of::<T>`
+    /// * `buffer` is not aligned to a multiple of `std::mem::align_of::<T>`
     /// * `bytes` is not large enough for the requested slice
     pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self {
         let size = std::mem::size_of::<T>();
diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs
index efd688129485..7604f3cd4d62 100644
--- a/arrow-flight/src/encode.rs
+++ b/arrow-flight/src/encode.rs
@@ -627,6 +627,7 @@ mod tests {
     use arrow_array::{cast::downcast_array, types::*};
     use arrow_buffer::Buffer;
     use arrow_cast::pretty::pretty_format_batches;
+    use arrow_ipc::MetadataVersion;
     use arrow_schema::UnionMode;
     use std::collections::HashMap;
@@ -638,7 +639,8 @@ mod tests {
     /// ensure only the batch's used data (not the allocated data) is sent
     ///
     fn test_encode_flight_data() {
-        let options = IpcWriteOptions::default();
+        // use 8-byte alignment - default alignment is 64 which produces bigger ipc data
+        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap();
 
         let c1 = UInt32Array::from(vec![1, 2, 3, 4, 5, 6]);
         let batch = RecordBatch::try_from_iter(vec![("a", Arc::new(c1) as ArrayRef)])
@@ -1343,6 +1345,8 @@ mod tests {
         let mut stream = FlightDataEncoderBuilder::new()
             .with_max_flight_data_size(max_flight_data_size)
+            // use 8-byte alignment - default alignment is 64 which produces bigger ipc data
+            .with_options(IpcWriteOptions::try_new(8, false, MetadataVersion::V5).unwrap())
             .build(futures::stream::iter([Ok(batch.clone())]));
 
         let mut i = 0;
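The two test changes above pin an 8-byte IPC alignment because, with this patch, the writer pads every buffer and message body out to the configured alignment (64 bytes by default), which inflates the encoded size of small batches. A minimal sketch of opting into the smaller padding when building a flight encoder, using only APIs that appear in this diff (`IpcWriteOptions::try_new` and `FlightDataEncoderBuilder::with_options`); the function name is illustrative:

    use arrow_flight::encode::FlightDataEncoderBuilder;
    use arrow_ipc::writer::IpcWriteOptions;
    use arrow_ipc::MetadataVersion;

    fn compact_flight_encoder() -> FlightDataEncoderBuilder {
        // After this patch, try_new accepts only 8, 16, 32, or 64; a smaller
        // alignment means less padding per buffer, at the cost of weaker
        // alignment guarantees for zero-copy readers.
        let options = IpcWriteOptions::try_new(8, false, MetadataVersion::V5)
            .expect("8 is a valid alignment");
        FlightDataEncoderBuilder::new().with_options(options)
    }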
diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs
index 4591777c1e37..8eac17e20761 100644
--- a/arrow-ipc/src/reader.rs
+++ b/arrow-ipc/src/reader.rs
@@ -78,6 +78,7 @@ fn create_array(
     reader: &mut ArrayReader,
     field: &Field,
     variadic_counts: &mut VecDeque<i64>,
+    require_alignment: bool,
 ) -> Result<ArrayRef, ArrowError> {
     let data_type = field.data_type();
     match data_type {
@@ -89,6 +90,7 @@ fn create_array(
                 reader.next_buffer()?,
                 reader.next_buffer()?,
             ],
+            require_alignment,
         ),
         BinaryView | Utf8View => {
             let count = variadic_counts
                 .pop_front()
                 .ok_or(ArrowError::IpcError(format!(
@@ -100,24 +102,42 @@ fn create_array(
             let buffers = (0..count)
                 .map(|_| reader.next_buffer())
                 .collect::<Result<Vec<_>, _>>()?;
-            create_primitive_array(reader.next_node(field)?, data_type, &buffers)
+            create_primitive_array(
+                reader.next_node(field)?,
+                data_type,
+                &buffers,
+                require_alignment,
+            )
         }
         FixedSizeBinary(_) => create_primitive_array(
             reader.next_node(field)?,
            data_type,
             &[reader.next_buffer()?, reader.next_buffer()?],
+            require_alignment,
         ),
         List(ref list_field) | LargeList(ref list_field) | Map(ref list_field, _) => {
             let list_node = reader.next_node(field)?;
             let list_buffers = [reader.next_buffer()?, reader.next_buffer()?];
-            let values = create_array(reader, list_field, variadic_counts)?;
-            create_list_array(list_node, data_type, &list_buffers, values)
+            let values = create_array(reader, list_field, variadic_counts, require_alignment)?;
+            create_list_array(
+                list_node,
+                data_type,
+                &list_buffers,
+                values,
+                require_alignment,
+            )
         }
         FixedSizeList(ref list_field, _) => {
             let list_node = reader.next_node(field)?;
             let list_buffers = [reader.next_buffer()?];
-            let values = create_array(reader, list_field, variadic_counts)?;
-            create_list_array(list_node, data_type, &list_buffers, values)
+            let values = create_array(reader, list_field, variadic_counts, require_alignment)?;
+            create_list_array(
+                list_node,
+                data_type,
+                &list_buffers,
+                values,
+                require_alignment,
+            )
         }
         Struct(struct_fields) => {
             let struct_node = reader.next_node(field)?;
@@ -128,7 +148,7 @@ fn create_array(
             // TODO investigate whether just knowing the number of buffers could
             // still work
             for struct_field in struct_fields {
-                let child = create_array(reader, struct_field, variadic_counts)?;
+                let child = create_array(reader, struct_field, variadic_counts, require_alignment)?;
                 struct_arrays.push((struct_field.clone(), child));
             }
             let null_count = struct_node.null_count() as usize;
@@ -142,18 +162,24 @@ fn create_array(
         }
         RunEndEncoded(run_ends_field, values_field) => {
             let run_node = reader.next_node(field)?;
-            let run_ends = create_array(reader, run_ends_field, variadic_counts)?;
-            let values = create_array(reader, values_field, variadic_counts)?;
+            let run_ends =
+                create_array(reader, run_ends_field, variadic_counts, require_alignment)?;
+            let values = create_array(reader, values_field, variadic_counts, require_alignment)?;
 
             let run_array_length = run_node.length() as usize;
-            let data = ArrayData::builder(data_type.clone())
+            let builder = ArrayData::builder(data_type.clone())
                 .len(run_array_length)
                 .offset(0)
                 .add_child_data(run_ends.into_data())
-                .add_child_data(values.into_data())
-                .build_aligned()?;
+                .add_child_data(values.into_data());
 
-            Ok(make_array(data))
+            let array_data = if require_alignment {
+                builder.build()?
+            } else {
+                builder.build_aligned()?
+            };
+
+            Ok(make_array(array_data))
         }
         // Create dictionary array from RecordBatch
         Dictionary(_, _) => {
@@ -170,7 +196,13 @@ fn create_array(
                 ))
             })?;
 
-            create_dictionary_array(index_node, data_type, &index_buffers, value_array.clone())
+            create_dictionary_array(
+                index_node,
+                data_type,
+                &index_buffers,
+                value_array.clone(),
+                require_alignment,
+            )
        }
         Union(fields, mode) => {
             let union_node = reader.next_node(field)?;
@@ -196,7 +228,7 @@ fn create_array(
             let mut ids = Vec::with_capacity(fields.len());
 
             for (id, field) in fields.iter() {
-                let child = create_array(reader, field, variadic_counts)?;
+                let child = create_array(reader, field, variadic_counts, require_alignment)?;
                 children.push((field.as_ref().clone(), child));
                 ids.push(id);
             }
@@ -215,18 +247,24 @@ fn create_array(
                 )));
             }
 
-            let data = ArrayData::builder(data_type.clone())
+            let builder = ArrayData::builder(data_type.clone())
                 .len(length as usize)
-                .offset(0)
-                .build_aligned()
-                .unwrap();
+                .offset(0);
+
+            let array_data = if require_alignment {
+                builder.build()?
+            } else {
+                builder.build_aligned()?
+            };
+
             // no buffer increases
-            Ok(Arc::new(NullArray::from(data)))
+            Ok(Arc::new(NullArray::from(array_data)))
         }
         _ => create_primitive_array(
             reader.next_node(field)?,
             data_type,
             &[reader.next_buffer()?, reader.next_buffer()?],
+            require_alignment,
         ),
     }
 }
@@ -237,34 +275,38 @@ fn create_primitive_array(
     field_node: &FieldNode,
     data_type: &DataType,
     buffers: &[Buffer],
+    require_alignment: bool,
 ) -> Result<ArrayRef, ArrowError> {
     let length = field_node.length() as usize;
     let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
-    let array_data = match data_type {
+    let builder = match data_type {
         Utf8 | Binary | LargeBinary | LargeUtf8 => {
             // read 3 buffers: null buffer (optional), offsets buffer and data buffer
             ArrayData::builder(data_type.clone())
                 .len(length)
                 .buffers(buffers[1..3].to_vec())
                 .null_bit_buffer(null_buffer)
-                .build_aligned()?
         }
         BinaryView | Utf8View => ArrayData::builder(data_type.clone())
             .len(length)
             .buffers(buffers[1..].to_vec())
-            .null_bit_buffer(null_buffer)
-            .build_aligned()?,
+            .null_bit_buffer(null_buffer),
         _ if data_type.is_primitive() || matches!(data_type, Boolean | FixedSizeBinary(_)) => {
             // read 2 buffers: null buffer (optional) and data buffer
             ArrayData::builder(data_type.clone())
                 .len(length)
                 .add_buffer(buffers[1].clone())
                 .null_bit_buffer(null_buffer)
-                .build_aligned()?
         }
         t => unreachable!("Data type {:?} either unsupported or not primitive", t),
     };
 
+    let array_data = if require_alignment {
+        builder.build()?
+    } else {
+        builder.build_aligned()?
+    };
+
     Ok(make_array(array_data))
 }
 
@@ -275,6 +317,7 @@ fn create_list_array(
     data_type: &DataType,
     buffers: &[Buffer],
     child_array: ArrayRef,
+    require_alignment: bool,
 ) -> Result<ArrayRef, ArrowError> {
     let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
     let length = field_node.length() as usize;
@@ -293,7 +336,14 @@ fn create_list_array(
         _ => unreachable!("Cannot create list or map array from {:?}", data_type),
     };
 
-    Ok(make_array(builder.build_aligned()?))
+
+    let array_data = if require_alignment {
+        builder.build()?
+    } else {
+        builder.build_aligned()?
+    };
+
+    Ok(make_array(array_data))
 }
 
 /// Reads the correct number of buffers based on list type and null_count, and creates a
@@ -303,6 +353,7 @@ fn create_dictionary_array(
     data_type: &DataType,
     buffers: &[Buffer],
     value_array: ArrayRef,
+    require_alignment: bool,
 ) -> Result<ArrayRef, ArrowError> {
     if let Dictionary(_, _) = *data_type {
         let null_buffer = (field_node.null_count() > 0).then_some(buffers[0].clone());
@@ -312,7 +363,13 @@ fn create_dictionary_array(
             .add_child_data(value_array.into_data())
             .null_bit_buffer(null_buffer);
 
-        Ok(make_array(builder.build_aligned()?))
+        let array_data = if require_alignment {
+            builder.build()?
+        } else {
+            builder.build_aligned()?
+        };
+
+        Ok(make_array(array_data))
     } else {
         unreachable!("Cannot create dictionary array from {:?}", data_type)
     }
@@ -428,7 +485,16 @@ impl<'a> ArrayReader<'a> {
     }
 }
 
-/// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`
+/// Creates a record batch from binary data using the `crate::RecordBatch` indexes and the `Schema`.
+///
+/// If `require_alignment` is true, this function will return an error if any array data in the
+/// input `buf` is not properly aligned.
+/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct [`arrow_data::ArrayData`].
+///
+/// If `require_alignment` is false, this function will automatically allocate a new aligned buffer
+/// and copy over the data if any array data in the input `buf` is not properly aligned.
+/// (Properly aligned array data will remain zero-copy.)
+/// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct [`arrow_data::ArrayData`].
 pub fn read_record_batch(
     buf: &Buffer,
     batch: crate::RecordBatch,
@@ -436,6 +502,38 @@ pub fn read_record_batch(
     dictionaries_by_id: &HashMap<i64, ArrayRef>,
     projection: Option<&[usize]>,
     metadata: &MetadataVersion,
+) -> Result<RecordBatch, ArrowError> {
+    read_record_batch_impl(
+        buf,
+        batch,
+        schema,
+        dictionaries_by_id,
+        projection,
+        metadata,
+        false,
+    )
+}
+
+/// Read the dictionary from the buffer and provided metadata,
+/// updating the `dictionaries_by_id` with the resulting dictionary
+pub fn read_dictionary(
+    buf: &Buffer,
+    batch: crate::DictionaryBatch,
+    schema: &Schema,
+    dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
+    metadata: &MetadataVersion,
+) -> Result<(), ArrowError> {
+    read_dictionary_impl(buf, batch, schema, dictionaries_by_id, metadata, false)
+}
+
+fn read_record_batch_impl(
+    buf: &Buffer,
+    batch: crate::RecordBatch,
+    schema: SchemaRef,
+    dictionaries_by_id: &HashMap<i64, ArrayRef>,
+    projection: Option<&[usize]>,
+    metadata: &MetadataVersion,
+    require_alignment: bool,
 ) -> Result<RecordBatch, ArrowError> {
     let buffers = batch.buffers().ok_or_else(|| {
         ArrowError::IpcError("Unable to get buffers from IPC RecordBatch".to_string())
     })?;
@@ -469,7 +567,8 @@ pub fn read_record_batch(
         for (idx, field) in schema.fields().iter().enumerate() {
             // Create array for projected field
             if let Some(proj_idx) = projection.iter().position(|p| p == &idx) {
-                let child = create_array(&mut reader, field, &mut variadic_counts)?;
+                let child =
+                    create_array(&mut reader, field, &mut variadic_counts, require_alignment)?;
                 arrays.push((proj_idx, child));
             } else {
                 reader.skip_field(field, &mut variadic_counts)?;
@@ -486,7 +585,7 @@ pub fn read_record_batch(
         let mut children = vec![];
         // keep track of index as lists require more than one node
         for field in schema.fields() {
-            let child = create_array(&mut reader, field, &mut variadic_counts)?;
+            let child = create_array(&mut reader, field, &mut variadic_counts, require_alignment)?;
             children.push(child);
         }
         assert!(variadic_counts.is_empty());
@@ -494,14 +593,13 @@ pub fn read_record_batch(
     }
 }
 
-/// Read the dictionary from the buffer and provided metadata,
-/// updating the `dictionaries_by_id` with the resulting dictionary
-pub fn read_dictionary(
+fn read_dictionary_impl(
     buf: &Buffer,
     batch: crate::DictionaryBatch,
     schema: &Schema,
     dictionaries_by_id: &mut HashMap<i64, ArrayRef>,
     metadata: &MetadataVersion,
+    require_alignment: bool,
 ) -> Result<(), ArrowError> {
     if batch.isDelta() {
         return Err(ArrowError::InvalidArgumentError(
@@ -524,13 +622,14 @@ pub fn read_dictionary(
             let value = value_type.as_ref().clone();
             let schema = Schema::new(vec![Field::new("", value, true)]);
             // Read a single column
-            let record_batch = read_record_batch(
+            let record_batch = read_record_batch_impl(
                 buf,
                 batch.data().unwrap(),
                 Arc::new(schema),
                 dictionaries_by_id,
                 None,
                 metadata,
+                require_alignment,
             )?;
             Some(record_batch.column(0).clone())
         }
@@ -655,6 +754,7 @@ pub struct FileDecoder {
     dictionaries: HashMap<i64, ArrayRef>,
     version: MetadataVersion,
     projection: Option<Vec<usize>>,
+    require_alignment: bool,
 }
 
 impl FileDecoder {
@@ -665,6 +765,7 @@ impl FileDecoder {
             version,
             dictionaries: Default::default(),
             projection: None,
+            require_alignment: false,
         }
     }
 
@@ -674,6 +775,23 @@ impl FileDecoder {
         self
     }
 
+    /// Specifies whether or not array data in input buffers is required to be properly aligned.
+    ///
+    /// If `require_alignment` is true, this decoder will return an error if any array data in the
+    /// input `buf` is not properly aligned.
+    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct
+    /// [`arrow_data::ArrayData`].
+    ///
+    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
+    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
+    /// properly aligned. (Properly aligned array data will remain zero-copy.)
+    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
+    /// [`arrow_data::ArrayData`].
+    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
+        self.require_alignment = require_alignment;
+        self
+    }
+
     fn read_message<'a>(&self, buf: &'a [u8]) -> Result<Message<'a>, ArrowError> {
         let message = parse_message(buf)?;
 
@@ -692,12 +810,13 @@ impl FileDecoder {
         match message.header_type() {
             crate::MessageHeader::DictionaryBatch => {
                 let batch = message.header_as_dictionary_batch().unwrap();
-                read_dictionary(
+                read_dictionary_impl(
                     &buf.slice(block.metaDataLength() as _),
                     batch,
                     &self.schema,
                     &mut self.dictionaries,
                     &message.version(),
+                    self.require_alignment,
                 )
             }
             t => Err(ArrowError::ParseError(format!(
@@ -722,13 +841,14 @@ impl FileDecoder {
                     ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
                 })?;
                 // read the block that makes up the record batch into a buffer
-                read_record_batch(
+                read_record_batch_impl(
                     &buf.slice(block.metaDataLength() as _),
                     batch,
                     self.schema.clone(),
                     &self.dictionaries,
                     self.projection.as_deref(),
                     &message.version(),
+                    self.require_alignment,
                 )
                 .map(Some)
             }
@@ -1164,13 +1284,14 @@ impl<R: Read> StreamReader<R> {
                 let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
                 self.reader.read_exact(&mut buf)?;
 
-                read_record_batch(
+                read_record_batch_impl(
                     &buf.into(),
                     batch,
                     self.schema(),
                     &self.dictionaries_by_id,
                     self.projection.as_ref().map(|x| x.0.as_ref()),
                     &message.version(),
+                    false,
                 )
                 .map(Some)
             }
@@ -1184,12 +1305,13 @@ impl<R: Read> StreamReader<R> {
                 let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
                 self.reader.read_exact(&mut buf)?;
 
-                read_dictionary(
+                read_dictionary_impl(
                     &buf.into(),
                     batch,
                     &self.schema,
                     &mut self.dictionaries_by_id,
                     &message.version(),
+                    false,
                 )?;
 
                 // read the next message until we encounter a RecordBatch
@@ -1955,18 +2077,61 @@ mod tests {
         assert_ne!(b.as_ptr().align_offset(8), 0);
 
         let ipc_batch = message.header_as_record_batch().unwrap();
-        let roundtrip = read_record_batch(
+        let roundtrip = read_record_batch_impl(
             &b,
             ipc_batch,
             batch.schema(),
             &Default::default(),
             None,
             &message.version(),
+            false,
         )
         .unwrap();
         assert_eq!(batch, roundtrip);
     }
 
+    #[test]
+    fn test_unaligned_throws_error_with_require_alignment() {
+        let batch = RecordBatch::try_from_iter(vec![(
+            "i32",
+            Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as _,
+        )])
+        .unwrap();
+
+        let gen = IpcDataGenerator {};
+        let mut dict_tracker = DictionaryTracker::new(false);
+        let (_, encoded) = gen
+            .encoded_batch(&batch, &mut dict_tracker, &Default::default())
+            .unwrap();
+
+        let message = root_as_message(&encoded.ipc_message).unwrap();
+
+        // Construct an unaligned buffer
+        let mut buffer = MutableBuffer::with_capacity(encoded.arrow_data.len() + 1);
+        buffer.push(0_u8);
+        buffer.extend_from_slice(&encoded.arrow_data);
+        let b = Buffer::from(buffer).slice(1);
+        assert_ne!(b.as_ptr().align_offset(8), 0);
+
+        let ipc_batch = message.header_as_record_batch().unwrap();
+        let result = read_record_batch_impl(
+            &b,
+            ipc_batch,
+            batch.schema(),
+            &Default::default(),
+            None,
+            &message.version(),
+            true,
+        );
+
+        let error = result.unwrap_err();
+        assert_eq!(
+            error.to_string(),
+            "Invalid argument error: Misaligned buffers[0] in array of type Int32, \
+            offset from expected alignment of 4 by 1"
+        );
+    }
+
     #[test]
     fn test_file_with_massive_column_count() {
         // 499_999 is upper limit for default settings (1_000_000)
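The reader-side changes above make strict zero-copy decoding opt-in: `FileDecoder::with_require_alignment(true)` turns misaligned input into an error instead of a silent copy. A sketch of end-to-end use against an IPC file held in a `Buffer`, modeled closely on the `test_decimal128_alignment16_is_sufficient` test added later in this patch (the function name and error handling are illustrative):

    use std::sync::Arc;
    use arrow_buffer::Buffer;
    use arrow_ipc::convert::fb_to_schema;
    use arrow_ipc::reader::{read_footer_length, FileDecoder};
    use arrow_ipc::root_as_footer;
    use arrow_schema::ArrowError;

    fn decode_file_zero_copy(buffer: Buffer) -> Result<(), ArrowError> {
        // The file trailer is 10 bytes: a 4-byte footer length plus 6-byte magic.
        let trailer_start = buffer.len() - 10;
        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap())?;
        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start])
            .map_err(|err| ArrowError::ParseError(err.to_string()))?;

        let schema = fb_to_schema(footer.schema().unwrap());
        // Fail loudly on misaligned buffers rather than copying them.
        let decoder =
            FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);

        let batches = footer.recordBatches().unwrap();
        for i in 0..batches.len() {
            let block = batches.get(i);
            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
            let data = buffer.slice_with_length(block.offset() as _, block_len);
            let _batch = decoder.read_record_batch(block, &data)?; // Err if misaligned
        }
        Ok(())
    }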
diff --git a/arrow-ipc/src/reader/stream.rs b/arrow-ipc/src/reader/stream.rs
index 7807228175ac..64191a22b33e 100644
--- a/arrow-ipc/src/reader/stream.rs
+++ b/arrow-ipc/src/reader/stream.rs
@@ -24,7 +24,7 @@ use arrow_buffer::{Buffer, MutableBuffer};
 use arrow_schema::{ArrowError, SchemaRef};
 
 use crate::convert::MessageBuffer;
-use crate::reader::{read_dictionary, read_record_batch};
+use crate::reader::{read_dictionary_impl, read_record_batch_impl};
 use crate::{MessageHeader, CONTINUATION_MARKER};
 
 /// A low-level interface for reading [`RecordBatch`] data from a stream of bytes
@@ -40,6 +40,8 @@ pub struct StreamDecoder {
     state: DecoderState,
     /// A scratch buffer when a read is split across multiple `Buffer`
     buf: MutableBuffer,
+    /// Whether or not array data in input buffers is required to be aligned
+    require_alignment: bool,
 }
 
 #[derive(Debug)]
@@ -83,6 +85,23 @@ impl StreamDecoder {
         Self::default()
     }
 
+    /// Specifies whether or not array data in input buffers is required to be properly aligned.
+    ///
+    /// If `require_alignment` is true, this decoder will return an error if any array data in the
+    /// input `buf` is not properly aligned.
+    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build`] to construct
+    /// [`arrow_data::ArrayData`].
+    ///
+    /// If `require_alignment` is false (the default), this decoder will automatically allocate a
+    /// new aligned buffer and copy over the data if any array data in the input `buf` is not
+    /// properly aligned. (Properly aligned array data will remain zero-copy.)
+    /// Under the hood it will use [`arrow_data::ArrayDataBuilder::build_aligned`] to construct
+    /// [`arrow_data::ArrayData`].
+    pub fn with_require_alignment(mut self, require_alignment: bool) -> Self {
+        self.require_alignment = require_alignment;
+        self
+    }
+
     /// Try to read the next [`RecordBatch`] from the provided [`Buffer`]
     ///
     /// [`Buffer::advance`] will be called on `buffer` for any consumed bytes.
@@ -192,13 +211,14 @@ impl StreamDecoder {
                         let schema = self.schema.clone().ok_or_else(|| {
                             ArrowError::IpcError("Missing schema".to_string())
                         })?;
-                        let batch = read_record_batch(
+                        let batch = read_record_batch_impl(
                             &body,
                             batch,
                             schema,
                             &self.dictionaries,
                             None,
                             &version,
+                            self.require_alignment,
                         )?;
                         self.state = DecoderState::default();
                         return Ok(Some(batch));
@@ -208,12 +228,13 @@ impl StreamDecoder {
                         let schema = self.schema.as_deref().ok_or_else(|| {
                             ArrowError::IpcError("Missing schema".to_string())
                         })?;
-                        read_dictionary(
+                        read_dictionary_impl(
                             &body,
                             dictionary,
                             schema,
                             &mut self.dictionaries,
                             &version,
+                            self.require_alignment,
                         )?;
                         self.state = DecoderState::default();
                     }
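`StreamDecoder` gets the same knob. A sketch of strict decoding from an in-memory stream; it assumes the whole encoded stream is already in one `Buffer` (`StreamDecoder::decode` returns `Ok(None)` when it needs more bytes, so a fully buffered stream simply drains to the end-of-stream marker):

    use arrow_buffer::Buffer;
    use arrow_ipc::reader::StreamDecoder;
    use arrow_schema::ArrowError;

    fn decode_stream_strict(mut bytes: Buffer) -> Result<usize, ArrowError> {
        // Error out instead of copying to a fresh aligned allocation when the
        // incoming bytes are misaligned.
        let mut decoder = StreamDecoder::new().with_require_alignment(true);
        let mut num_batches = 0;
        while let Some(_batch) = decoder.decode(&mut bytes)? {
            num_batches += 1;
        }
        Ok(num_batches)
    }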
diff --git a/arrow-ipc/src/writer.rs b/arrow-ipc/src/writer.rs
index 2a3474fe0fc6..97136bd97c2f 100644
--- a/arrow-ipc/src/writer.rs
+++ b/arrow-ipc/src/writer.rs
@@ -43,8 +43,8 @@ use crate::CONTINUATION_MARKER;
 #[derive(Debug, Clone)]
 pub struct IpcWriteOptions {
     /// Write padding after memory buffers to this multiple of bytes.
-    /// Generally 8 or 64, defaults to 64
-    alignment: usize,
+    /// Must be 8, 16, 32, or 64 - defaults to 64.
+    alignment: u8,
     /// The legacy format is for releases before 0.15.0, and uses metadata V4
     write_legacy_ipc_format: bool,
     /// The metadata version to write. The Rust IPC writer supports V4+
@@ -87,11 +87,14 @@ impl IpcWriteOptions {
         write_legacy_ipc_format: bool,
         metadata_version: crate::MetadataVersion,
     ) -> Result<Self, ArrowError> {
-        if alignment == 0 || alignment % 8 != 0 {
+        let is_alignment_valid =
+            alignment == 8 || alignment == 16 || alignment == 32 || alignment == 64;
+        if !is_alignment_valid {
             return Err(ArrowError::InvalidArgumentError(
-                "Alignment should be greater than 0 and be a multiple of 8".to_string(),
+                "Alignment should be 8, 16, 32, or 64.".to_string(),
             ));
         }
+        let alignment: u8 = u8::try_from(alignment).expect("range already checked");
         match metadata_version {
             crate::MetadataVersion::V1
             | crate::MetadataVersion::V2
@@ -432,8 +435,8 @@ impl IpcDataGenerator {
         }
         // pad the tail of body data
         let len = arrow_data.len();
-        let pad_len = pad_to_8(len as u32);
-        arrow_data.extend_from_slice(&vec![0u8; pad_len][..]);
+        let pad_len = pad_to_alignment(write_options.alignment, len);
+        arrow_data.extend_from_slice(&PADDING[..pad_len]);
 
         // write data
         let buffers = fbb.create_vector(&buffers);
@@ -520,8 +523,8 @@ impl IpcDataGenerator {
 
         // pad the tail of body data
         let len = arrow_data.len();
-        let pad_len = pad_to_8(len as u32);
-        arrow_data.extend_from_slice(&vec![0u8; pad_len][..]);
+        let pad_len = pad_to_alignment(write_options.alignment, len);
+        arrow_data.extend_from_slice(&PADDING[..pad_len]);
 
         // write data
         let buffers = fbb.create_vector(&buffers);
@@ -760,11 +763,11 @@ impl<W: Write> FileWriter<W> {
     ) -> Result<Self, ArrowError> {
         let data_gen = IpcDataGenerator::default();
         let mut writer = BufWriter::new(writer);
-        // write magic to header aligned on 8 byte boundary
-        let header_size = super::ARROW_MAGIC.len() + 2;
-        assert_eq!(header_size, 8);
-        writer.write_all(&super::ARROW_MAGIC[..])?;
-        writer.write_all(&[0, 0])?;
+        // write magic to header aligned on alignment boundary
+        let pad_len = pad_to_alignment(write_options.alignment, super::ARROW_MAGIC.len());
+        let header_size = super::ARROW_MAGIC.len() + pad_len;
+        writer.write_all(&super::ARROW_MAGIC)?;
+        writer.write_all(&PADDING[..pad_len])?;
         // write the schema, set the written bytes to the schema + header
         let encoded_message = data_gen.schema_to_bytes(schema, &write_options);
         let (meta, data) = write_message(&mut writer, encoded_message, &write_options)?;
@@ -1061,13 +1064,13 @@ pub fn write_message<W: Write>(
     write_options: &IpcWriteOptions,
 ) -> Result<(usize, usize), ArrowError> {
     let arrow_data_len = encoded.arrow_data.len();
-    if arrow_data_len % 8 != 0 {
+    if arrow_data_len % usize::from(write_options.alignment) != 0 {
         return Err(ArrowError::MemoryError(
             "Arrow data not aligned".to_string(),
         ));
     }
 
-    let a = write_options.alignment - 1;
+    let a = usize::from(write_options.alignment - 1);
     let buffer = encoded.ipc_message;
     let flatbuf_size = buffer.len();
     let prefix_size = if write_options.write_legacy_ipc_format {
@@ -1089,11 +1092,11 @@ pub fn write_message<W: Write>(
         writer.write_all(&buffer)?;
     }
     // write padding
-    writer.write_all(&vec![0; padding_bytes])?;
+    writer.write_all(&PADDING[..padding_bytes])?;
 
     // write arrow data
     let body_len = if arrow_data_len > 0 {
-        write_body_buffers(&mut writer, &encoded.arrow_data)?
+        write_body_buffers(&mut writer, &encoded.arrow_data, write_options.alignment)?
     } else {
         0
     };
@@ -1101,19 +1104,23 @@ pub fn write_message<W: Write>(
 
     Ok((aligned_size, body_len))
 }
 
-fn write_body_buffers(mut writer: W, data: &[u8]) -> Result<usize, ArrowError> {
-    let len = data.len() as u32;
-    let pad_len = pad_to_8(len) as u32;
+fn write_body_buffers<W: Write>(
+    mut writer: W,
+    data: &[u8],
+    alignment: u8,
+) -> Result<usize, ArrowError> {
+    let len = data.len();
+    let pad_len = pad_to_alignment(alignment, len);
     let total_len = len + pad_len;
 
     // write body buffer
     writer.write_all(data)?;
     if pad_len > 0 {
-        writer.write_all(&vec![0u8; pad_len as usize][..])?;
+        writer.write_all(&PADDING[..pad_len])?;
     }
 
     writer.flush()?;
-    Ok(total_len as usize)
+    Ok(total_len)
 }
 
 /// Write a record batch to the writer, writing the message size before the message
@@ -1278,6 +1285,7 @@ fn write_array_data(
                 arrow_data,
                 offset,
                 compression_codec,
+                write_options.alignment,
             )?;
         }
 
@@ -1291,6 +1299,7 @@ fn write_array_data(
                 arrow_data,
                 offset,
                 compression_codec,
+                write_options.alignment,
             )?;
         }
     } else if matches!(data_type, DataType::BinaryView | DataType::Utf8View) {
@@ -1307,6 +1316,7 @@ fn write_array_data(
                 arrow_data,
                 offset,
                 compression_codec,
+                write_options.alignment,
             )?;
         }
     } else if matches!(data_type, DataType::LargeBinary | DataType::LargeUtf8) {
@@ -1318,6 +1328,7 @@ fn write_array_data(
                 arrow_data,
                 offset,
                 compression_codec,
+                write_options.alignment,
             )?;
         }
     } else if DataType::is_numeric(data_type)
@@ -1343,7 +1354,14 @@ fn write_array_data(
         } else {
             buffer.as_slice()
         };
-        offset = write_buffer(buffer_slice, buffers, arrow_data, offset, compression_codec)?;
+        offset = write_buffer(
+            buffer_slice,
+            buffers,
+            arrow_data,
+            offset,
+            compression_codec,
+            write_options.alignment,
+        )?;
     } else if matches!(data_type, DataType::Boolean) {
         // Bools are special because the payload (= 1 bit) is smaller than the physical container elements (= bytes).
         // The array data may not start at the physical boundary of the underlying buffer, so we need to shift bits around.
@@ -1351,7 +1369,14 @@ fn write_array_data(
         let buffer = &array_data.buffers()[0];
         let buffer = buffer.bit_slice(array_data.offset(), array_data.len());
-        offset = write_buffer(&buffer, buffers, arrow_data, offset, compression_codec)?;
+        offset = write_buffer(
+            &buffer,
+            buffers,
+            arrow_data,
+            offset,
+            compression_codec,
+            write_options.alignment,
+        )?;
     } else if matches!(
         data_type,
         DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _)
@@ -1372,6 +1397,7 @@ fn write_array_data(
             arrow_data,
             offset,
             compression_codec,
+            write_options.alignment,
         )?;
         offset = write_array_data(
             &sliced_child_data,
@@ -1387,7 +1413,14 @@ fn write_array_data(
         return Ok(offset);
     } else {
         for buffer in array_data.buffers() {
-            offset = write_buffer(buffer, buffers, arrow_data, offset, compression_codec)?;
+            offset = write_buffer(
+                buffer,
+                buffers,
+                arrow_data,
+                offset,
+                compression_codec,
+                write_options.alignment,
+            )?;
         }
     }
 
@@ -1451,6 +1484,7 @@ fn write_buffer(
     arrow_data: &mut Vec<u8>, // output stream
     offset: i64, // current output stream offset
     compression_codec: Option<CompressionCodec>,
+    alignment: u8,
 ) -> Result<i64, ArrowError> {
     let len: i64 = match compression_codec {
         Some(compressor) => compressor.compress_to_vec(buffer, arrow_data)?,
@@ -1466,17 +1500,20 @@ fn write_buffer(
     // make new index entry
     buffers.push(crate::Buffer::new(offset, len));
 
-    // padding and make offset 8 bytes aligned
-    let pad_len = pad_to_8(len as u32) as i64;
-    arrow_data.extend_from_slice(&vec![0u8; pad_len as usize][..]);
+    // padding and make offset aligned
+    let pad_len = pad_to_alignment(alignment, len as usize);
+    arrow_data.extend_from_slice(&PADDING[..pad_len]);
 
-    Ok(offset + len + pad_len)
+    Ok(offset + len + (pad_len as i64))
 }
 
-/// Calculate an 8-byte boundary and return the number of bytes needed to pad to 8 bytes
+const PADDING: [u8; 64] = [0; 64];
+
+/// Calculate an alignment boundary and return the number of bytes needed to pad to the alignment boundary
 #[inline]
-fn pad_to_8(len: u32) -> usize {
-    (((len + 7) & !7) - len) as usize
+fn pad_to_alignment(alignment: u8, len: usize) -> usize {
+    let a = usize::from(alignment - 1);
+    ((len + a) & !a) - len
 }
 
 #[cfg(test)]
@@ -1490,7 +1527,9 @@ mod tests {
 
     use arrow_array::builder::{PrimitiveRunBuilder, UInt32Builder};
     use arrow_array::types::*;
 
+    use crate::convert::fb_to_schema;
     use crate::reader::*;
+    use crate::root_as_footer;
     use crate::MetadataVersion;
 
     use super::*;
@@ -1508,7 +1547,17 @@ mod tests {
     }
 
     fn serialize_stream(record: &RecordBatch) -> Vec<u8> {
-        let mut stream_writer = StreamWriter::try_new(vec![], record.schema_ref()).unwrap();
+        // Use 8-byte alignment so that the various `truncate_*` tests can be compactly written,
+        // without needing to construct a giant array to spill over the 64-byte default alignment
+        // boundary.
+        const IPC_ALIGNMENT: usize = 8;
+
+        let mut stream_writer = StreamWriter::try_new_with_options(
+            vec![],
+            record.schema_ref(),
+            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
+        )
+        .unwrap();
         stream_writer.write(record).unwrap();
         stream_writer.finish().unwrap();
         stream_writer.into_inner().unwrap()
     }
@@ -2345,4 +2394,120 @@ mod tests {
         let in_batch = RecordBatch::try_new(schema, vec![values]).unwrap();
         roundtrip_ensure_sliced_smaller(in_batch, 1000);
     }
+
+    #[test]
+    fn test_decimal128_alignment16_is_sufficient() {
+        const IPC_ALIGNMENT: usize = 16;
+
+        // Test a bunch of different dimensions to ensure alignment is never an issue.
+        // For example, if we only test `num_cols = 1` then even with alignment 8 this
+        // test would _happen_ to pass, even though for different dimensions like
+        // `num_cols = 2` it would fail.
+        for num_cols in [1, 2, 3, 17, 50, 73, 99] {
+            let num_rows = (num_cols * 7 + 11) % 100; // Deterministic swizzle
+
+            let mut fields = Vec::new();
+            let mut arrays = Vec::new();
+            for i in 0..num_cols {
+                let field = Field::new(&format!("col_{}", i), DataType::Decimal128(38, 10), true);
+                let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
+                fields.push(field);
+                arrays.push(Arc::new(array) as Arc<dyn Array>);
+            }
+            let schema = Schema::new(fields);
+            let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();
+
+            let mut writer = FileWriter::try_new_with_options(
+                Vec::new(),
+                batch.schema_ref(),
+                IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
+            )
+            .unwrap();
+            writer.write(&batch).unwrap();
+            writer.finish().unwrap();
+
+            let out: Vec<u8> = writer.into_inner().unwrap();
+
+            let buffer = Buffer::from_vec(out);
+            let trailer_start = buffer.len() - 10;
+            let footer_len =
+                read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
+            let footer =
+                root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
+
+            let schema = fb_to_schema(footer.schema().unwrap());
+
+            // Importantly we set `require_alignment`, checking that 16-byte alignment is sufficient
+            // for `read_record_batch` later on to read the data in a zero-copy manner.
+            let decoder =
+                FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);
+
+            let batches = footer.recordBatches().unwrap();
+
+            let block = batches.get(0);
+            let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
+            let data = buffer.slice_with_length(block.offset() as _, block_len);
+
+            let batch2 = decoder.read_record_batch(block, &data).unwrap().unwrap();
+
+            assert_eq!(batch, batch2);
+        }
+    }
+
+    #[test]
+    fn test_decimal128_alignment8_is_unaligned() {
+        const IPC_ALIGNMENT: usize = 8;
+
+        let num_cols = 2;
+        let num_rows = 1;
+
+        let mut fields = Vec::new();
+        let mut arrays = Vec::new();
+        for i in 0..num_cols {
+            let field = Field::new(&format!("col_{}", i), DataType::Decimal128(38, 10), true);
+            let array = Decimal128Array::from(vec![num_cols as i128; num_rows]);
+            fields.push(field);
+            arrays.push(Arc::new(array) as Arc<dyn Array>);
+        }
+        let schema = Schema::new(fields);
+        let batch = RecordBatch::try_new(Arc::new(schema), arrays).unwrap();
+
+        let mut writer = FileWriter::try_new_with_options(
+            Vec::new(),
+            batch.schema_ref(),
+            IpcWriteOptions::try_new(IPC_ALIGNMENT, false, MetadataVersion::V5).unwrap(),
+        )
+        .unwrap();
+        writer.write(&batch).unwrap();
+        writer.finish().unwrap();
+
+        let out: Vec<u8> = writer.into_inner().unwrap();
+
+        let buffer = Buffer::from_vec(out);
+        let trailer_start = buffer.len() - 10;
+        let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap();
+        let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap();
+
+        let schema = fb_to_schema(footer.schema().unwrap());
+
+        // Importantly we set `require_alignment`, otherwise the error later is suppressed due to copying
+        // to an aligned buffer in `ArrayDataBuilder.build_aligned`.
+        let decoder =
+            FileDecoder::new(Arc::new(schema), footer.version()).with_require_alignment(true);
+
+        let batches = footer.recordBatches().unwrap();
+
+        let block = batches.get(0);
+        let block_len = block.bodyLength() as usize + block.metaDataLength() as usize;
+        let data = buffer.slice_with_length(block.offset() as _, block_len);
+
+        let result = decoder.read_record_batch(block, &data);
+
+        let error = result.unwrap_err();
+        assert_eq!(
+            error.to_string(),
+            "Invalid argument error: Misaligned buffers[0] in array of type Decimal128(38, 10), \
+            offset from expected alignment of 16 by 8"
+        );
+    }
 }
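The padding math introduced above relies on the alignment being a power of two: for `a = alignment - 1`, `(len + a) & !a` rounds `len` up to the next multiple of the alignment. A standalone rendition of the new private `pad_to_alignment` helper with a few spot checks (the `main` wrapper is just for illustration):

    // Mirrors the private pad_to_alignment helper added in this patch.
    fn pad_to_alignment(alignment: u8, len: usize) -> usize {
        let a = usize::from(alignment - 1);
        ((len + a) & !a) - len
    }

    fn main() {
        assert_eq!(pad_to_alignment(8, 0), 0); // already aligned
        assert_eq!(pad_to_alignment(8, 5), 3); // 5 rounds up to 8
        assert_eq!(pad_to_alignment(64, 65), 63); // 65 rounds up to 128
        assert_eq!(pad_to_alignment(16, 48), 0); // exact multiples need no padding
    }

This is also why `test_decimal128_alignment8_is_unaligned` expects an error: with 8-byte alignment, a buffer of 16-byte `Decimal128` values can legally start at an offset that is 8 modulo 16, which `require_alignment` rejects instead of silently copying.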