From 98634862994f88f5bebb4ef34cb49dcf6361c95a Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Sat, 30 Dec 2023 06:46:17 +0000 Subject: [PATCH] Add IPC FileDecoder (#5249) * Add IPC FileDecoder * Clippy * Update arrow-ipc/src/reader.rs Co-authored-by: Liang-Chi Hsieh --------- Co-authored-by: Liang-Chi Hsieh --- arrow-ipc/src/reader.rs | 283 ++++++++++++++++++++++++++-------------- 1 file changed, 186 insertions(+), 97 deletions(-) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 39365b782944..8ac3a387d5bb 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -451,7 +451,7 @@ pub fn read_dictionary( batch: crate::DictionaryBatch, schema: &Schema, dictionaries_by_id: &mut HashMap, - metadata: &crate::MetadataVersion, + metadata: &MetadataVersion, ) -> Result<(), ArrowError> { if batch.isDelta() { return Err(ArrowError::InvalidArgumentError( @@ -522,6 +522,174 @@ fn parse_message(buf: &[u8]) -> Result { .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}"))) } +/// Read the footer length from the last 10 bytes of an Arrow IPC file +/// +/// Expects a 4 byte footer length followed by `b"ARROW1"` +pub fn read_footer_length(buf: [u8; 10]) -> Result { + if buf[4..] != super::ARROW_MAGIC { + return Err(ArrowError::ParseError( + "Arrow file does not contain correct footer".to_string(), + )); + } + + // read footer length + let footer_len = i32::from_le_bytes(buf[..4].try_into().unwrap()); + footer_len + .try_into() + .map_err(|_| ArrowError::ParseError(format!("Invalid footer length: {footer_len}"))) +} + +/// A low-level, push-based interface for reading an IPC file +/// +/// For a higher-level interface see [`FileReader`] +/// +/// ``` +/// # use std::sync::Arc; +/// # use arrow_array::*; +/// # use arrow_array::types::Int32Type; +/// # use arrow_buffer::Buffer; +/// # use arrow_ipc::convert::fb_to_schema; +/// # use arrow_ipc::reader::{FileDecoder, read_footer_length}; +/// # use arrow_ipc::root_as_footer; +/// # use arrow_ipc::writer::FileWriter; +/// // Write an IPC file +/// +/// let batch = RecordBatch::try_from_iter([ +/// ("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _), +/// ("b", Arc::new(Int32Array::from(vec![1, 2, 3])) as _), +/// ("c", Arc::new(DictionaryArray::::from_iter(["hello", "hello", "world"])) as _), +/// ]).unwrap(); +/// +/// let schema = batch.schema(); +/// +/// let mut out = Vec::with_capacity(1024); +/// let mut writer = FileWriter::try_new(&mut out, schema.as_ref()).unwrap(); +/// writer.write(&batch).unwrap(); +/// writer.finish().unwrap(); +/// +/// drop(writer); +/// +/// // Read IPC file +/// +/// let buffer = Buffer::from_vec(out); +/// let trailer_start = buffer.len() - 10; +/// let footer_len = read_footer_length(buffer[trailer_start..].try_into().unwrap()).unwrap(); +/// let footer = root_as_footer(&buffer[trailer_start - footer_len..trailer_start]).unwrap(); +/// +/// let back = fb_to_schema(footer.schema().unwrap()); +/// assert_eq!(&back, schema.as_ref()); +/// +/// let mut decoder = FileDecoder::new(schema, footer.version()); +/// +/// // Read dictionaries +/// for block in footer.dictionaries().iter().flatten() { +/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; +/// let data = buffer.slice_with_length(block.offset() as _, block_len); +/// decoder.read_dictionary(&block, &data).unwrap(); +/// } +/// +/// // Read record batch +/// let batches = footer.recordBatches().unwrap(); +/// assert_eq!(batches.len(), 1); // Only wrote a single batch +/// +/// let block = batches.get(0); +/// let block_len = block.bodyLength() as usize + block.metaDataLength() as usize; +/// let data = buffer.slice_with_length(block.offset() as _, block_len); +/// let back = decoder.read_record_batch(block, &data).unwrap().unwrap(); +/// +/// assert_eq!(batch, back); +/// ``` +#[derive(Debug)] +pub struct FileDecoder { + schema: SchemaRef, + dictionaries: HashMap, + version: MetadataVersion, + projection: Option>, +} + +impl FileDecoder { + /// Create a new [`FileDecoder`] with the given schema and version + pub fn new(schema: SchemaRef, version: MetadataVersion) -> Self { + Self { + schema, + version, + dictionaries: Default::default(), + projection: None, + } + } + + /// Specify a projection + pub fn with_projection(mut self, projection: Vec) -> Self { + self.projection = Some(projection); + self + } + + fn read_message<'a>(&self, buf: &'a [u8]) -> Result, ArrowError> { + let message = parse_message(buf)?; + + // some old test data's footer metadata is not set, so we account for that + if self.version != MetadataVersion::V1 && message.version() != self.version { + return Err(ArrowError::IpcError( + "Could not read IPC message as metadata versions mismatch".to_string(), + )); + } + Ok(message) + } + + /// Read the dictionary with the given block and data buffer + pub fn read_dictionary(&mut self, block: &Block, buf: &Buffer) -> Result<(), ArrowError> { + let message = self.read_message(buf)?; + match message.header_type() { + crate::MessageHeader::DictionaryBatch => { + let batch = message.header_as_dictionary_batch().unwrap(); + read_dictionary( + &buf.slice(block.metaDataLength() as _), + batch, + &self.schema, + &mut self.dictionaries, + &message.version(), + ) + } + t => Err(ArrowError::ParseError(format!( + "Expecting DictionaryBatch in dictionary blocks, found {t:?}." + ))), + } + } + + /// Read the RecordBatch with the given block and data buffer + pub fn read_record_batch( + &self, + block: &Block, + buf: &Buffer, + ) -> Result, ArrowError> { + let message = self.read_message(buf)?; + match message.header_type() { + crate::MessageHeader::Schema => Err(ArrowError::IpcError( + "Not expecting a schema when messages are read".to_string(), + )), + crate::MessageHeader::RecordBatch => { + let batch = message.header_as_record_batch().ok_or_else(|| { + ArrowError::IpcError("Unable to read IPC message as record batch".to_string()) + })?; + // read the block that makes up the record batch into a buffer + read_record_batch( + &buf.slice(block.metaDataLength() as _), + batch, + self.schema.clone(), + &self.dictionaries, + self.projection.as_deref(), + &message.version(), + ) + .map(Some) + } + crate::MessageHeader::NONE => Ok(None), + t => Err(ArrowError::InvalidArgumentError(format!( + "Reading types other than record batches not yet supported, unable to read {t:?}" + ))), + } + } +} + /// Build an Arrow [`FileReader`] with custom options. #[derive(Debug)] pub struct FileReaderBuilder { @@ -599,17 +767,10 @@ impl FileReaderBuilder { reader.seek(SeekFrom::End(-10))?; reader.read_exact(&mut buffer)?; - if buffer[4..] != super::ARROW_MAGIC { - return Err(ArrowError::ParseError( - "Arrow file does not contain correct footer".to_string(), - )); - } - - // read footer length - let footer_len = i32::from_le_bytes(buffer[..4].try_into().unwrap()); + let footer_len = read_footer_length(buffer)?; // read footer - let mut footer_data = vec![0; footer_len as usize]; + let mut footer_data = vec![0; footer_len]; reader.seek(SeekFrom::End(-10 - footer_len as i64))?; reader.read_exact(&mut footer_data)?; @@ -641,50 +802,26 @@ impl FileReaderBuilder { } } + let mut decoder = FileDecoder::new(Arc::new(schema), footer.version()); + if let Some(projection) = self.projection { + decoder = decoder.with_projection(projection) + } + // Create an array of optional dictionary value arrays, one per field. - let mut dictionaries_by_id = HashMap::new(); if let Some(dictionaries) = footer.dictionaries() { for block in dictionaries { let buf = read_block(&mut reader, block)?; - let message = parse_message(&buf)?; - - match message.header_type() { - crate::MessageHeader::DictionaryBatch => { - let batch = message.header_as_dictionary_batch().unwrap(); - read_dictionary( - &buf.slice(block.metaDataLength() as _), - batch, - &schema, - &mut dictionaries_by_id, - &message.version(), - )?; - } - t => { - return Err(ArrowError::ParseError(format!( - "Expecting DictionaryBatch in dictionary blocks, found {t:?}." - ))); - } - } + decoder.read_dictionary(block, &buf)?; } } - let projection = match self.projection { - Some(projection_indices) => { - let schema = schema.project(&projection_indices)?; - Some((projection_indices, schema)) - } - _ => None, - }; Ok(FileReader { reader, - schema: Arc::new(schema), blocks: blocks.iter().copied().collect(), current_block: 0, total_blocks, - dictionaries_by_id, - metadata_version: footer.version(), + decoder, custom_metadata, - projection, }) } } @@ -694,13 +831,13 @@ pub struct FileReader { /// Buffered file reader that supports reading and seeking reader: R, - /// The schema that is read from the file header - schema: SchemaRef, + /// The decoder + decoder: FileDecoder, /// The blocks in the file /// /// A block indicates the regions in the file to read to get data - blocks: Vec, + blocks: Vec, /// A counter to keep track of the current block that should be read current_block: usize, @@ -708,31 +845,17 @@ pub struct FileReader { /// The total number of blocks, which may contain record batches and other types total_blocks: usize, - /// Optional dictionaries for each schema field. - /// - /// Dictionaries may be appended to in the streaming format. - dictionaries_by_id: HashMap, - - /// Metadata version - metadata_version: crate::MetadataVersion, - /// User defined metadata custom_metadata: HashMap, - - /// Optional projection and projected_schema - projection: Option<(Vec, Schema)>, } impl fmt::Debug for FileReader { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { f.debug_struct("FileReader") - .field("schema", &self.schema) + .field("decoder", &self.decoder) .field("blocks", &self.blocks) .field("current_block", &self.current_block) .field("total_blocks", &self.total_blocks) - .field("dictionaries_by_id", &self.dictionaries_by_id) - .field("metadata_version", &self.metadata_version) - .field("projection", &self.projection) .finish_non_exhaustive() } } @@ -761,7 +884,7 @@ impl FileReader { /// Return the schema of the file pub fn schema(&self) -> SchemaRef { - self.schema.clone() + self.decoder.schema.clone() } /// Read a specific record batch @@ -785,41 +908,7 @@ impl FileReader { // read length let buffer = read_block(&mut self.reader, block)?; - let message = parse_message(&buffer)?; - - // some old test data's footer metadata is not set, so we account for that - if self.metadata_version != MetadataVersion::V1 - && message.version() != self.metadata_version - { - return Err(ArrowError::IpcError( - "Could not read IPC message as metadata versions mismatch".to_string(), - )); - } - - match message.header_type() { - crate::MessageHeader::Schema => Err(ArrowError::IpcError( - "Not expecting a schema when messages are read".to_string(), - )), - crate::MessageHeader::RecordBatch => { - let batch = message.header_as_record_batch().ok_or_else(|| { - ArrowError::IpcError("Unable to read IPC message as record batch".to_string()) - })?; - // read the block that makes up the record batch into a buffer - read_record_batch( - &buffer.slice(block.metaDataLength() as _), - batch, - self.schema(), - &self.dictionaries_by_id, - self.projection.as_ref().map(|x| x.0.as_ref()), - &message.version(), - ) - .map(Some) - } - crate::MessageHeader::NONE => Ok(None), - t => Err(ArrowError::InvalidArgumentError(format!( - "Reading types other than record batches not yet supported, unable to read {t:?}" - ))), - } + self.decoder.read_record_batch(block, &buffer) } /// Gets a reference to the underlying reader. @@ -852,7 +941,7 @@ impl Iterator for FileReader { impl RecordBatchReader for FileReader { fn schema(&self) -> SchemaRef { - self.schema.clone() + self.schema() } }