Skip to content

Commit

Permalink
This uses a struct wrapping Buffer and implementing
Browse files Browse the repository at this point in the history
AsRef<[u8]> instead of a breaking trait.
  • Loading branch information
Sven Cattell committed Dec 6, 2023
1 parent 5bae721 commit cbb0c52
Showing 1 changed file with 78 additions and 101 deletions.
179 changes: 78 additions & 101 deletions arrow-ipc/src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
use flatbuffers::VectorIter;
use std::collections::HashMap;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::fmt;
use std::io::{BufReader, Cursor, Read, Seek, SeekFrom};
use std::sync::Arc;
use std::{fmt, io};

use arrow_array::*;
use arrow_buffer::{Buffer, MutableBuffer};
Expand Down Expand Up @@ -498,67 +498,24 @@ pub fn read_dictionary(
Ok(())
}

pub trait BufferRead {
fn read_buffer(&mut self, len: usize) -> Result<Buffer, ArrowError>;
}

impl<R: Read + Seek> BufferRead for BufReader<R> {
fn read_buffer(&mut self, len: usize) -> Result<Buffer, ArrowError> {
let mut buf = MutableBuffer::from_len_zeroed(len);
self.read_exact(&mut buf)?;
Ok(buf.into())
}
}

#[derive(Clone, Debug)]
struct BufferReader {
buffer: Buffer,
position: u64,
inter: Buffer,
}

impl Seek for BufferReader {
fn seek(&mut self, style: SeekFrom) -> io::Result<u64> {
let (base_pos, offset) = match style {
SeekFrom::Start(n) => {
self.position = n;
return Ok(n);
}
SeekFrom::End(n) => (self.buffer.as_ref().len() as u64, n),
SeekFrom::Current(n) => (self.position, n),
};
match base_pos.checked_add_signed(offset) {
Some(n) => {
self.position = n;
Ok(self.position)
}
None => Err(io::Error::new(
std::io::ErrorKind::InvalidInput,
"invalid seek to a negative or overflowing position",
)),
}
}
}

impl Read for BufferReader {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let mut remaining_slice = &self.buffer[self.position as usize..];
let n = Read::read(&mut remaining_slice, buf)?;
self.position += n as u64;
Ok(n)
}
}

impl BufferRead for BufferReader {
fn read_buffer(&mut self, len: usize) -> Result<Buffer, ArrowError> {
let buf = self.buffer.slice_with_length(self.position as usize, len);
self.position += len as u64;
Ok(buf)
impl AsRef<[u8]> for BufferReader {
fn as_ref(&self) -> &[u8] {
&self.inter
}
}

/// Arrow File reader
pub struct FileReader<R: BufferRead> {
pub struct FileReader<R: Read + Seek> {
/// Buffered file reader that supports reading and seeking
reader: R,
reader: BufReader<R>,

/// Optional Buffer for when we want to map the underlying data without copying
buffer: Option<BufferReader>,

/// The schema that is read from the file header
schema: SchemaRef,
Expand Down Expand Up @@ -589,7 +546,7 @@ pub struct FileReader<R: BufferRead> {
projection: Option<(Vec<usize>, Schema)>,
}

impl<R: BufferRead> fmt::Debug for FileReader<R> {
impl<R: Read + Seek> fmt::Debug for FileReader<R> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::result::Result<(), fmt::Error> {
f.debug_struct("FileReader<R>")
.field("reader", &"BufReader<..>")
Expand All @@ -605,52 +562,35 @@ impl<R: BufferRead> fmt::Debug for FileReader<R> {
}
}

impl<R: Read + Seek> FileReader<BufReader<R>> {
impl FileReader<Cursor<BufferReader>> {
pub fn try_new_map(buffer: Buffer, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
let buffer = BufferReader { inter: buffer };
let cursor = Cursor::new(buffer.clone());
Self::try_new_inter(cursor, projection, Some(buffer))
}
}

impl<R: Read + Seek> FileReader<R> {
/// Try to create a new file reader
///
/// Returns errors if the file does not meet the Arrow Format header and footer
/// requirements
pub fn try_new(reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
Self::try_new_inter(BufReader::new(reader), projection)
}

/// Gets a reference to the underlying reader.
///
/// It is inadvisable to directly read from the underlying reader.
pub fn get_ref(&self) -> &R {
self.reader.get_ref()
}

/// Gets a mutable reference to the underlying reader.
///
/// It is inadvisable to directly read from the underlying reader.
pub fn get_mut(&mut self) -> &mut R {
self.reader.get_mut()
}
}

impl FileReader<BufferReader> {
pub fn try_new_map(buffer: Buffer, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
let reader = BufferReader {
buffer,
position: 0,
};
Self::try_new_inter(reader, projection)
Self::try_new_inter(reader, projection, None)
}

pub fn get_buf(&self) -> &Buffer {
&self.reader.buffer
}
}

impl<R: Read + Seek + BufferRead> FileReader<R> {
/// Try to create a new file reader
///
/// Returns errors if the file does not meet the Arrow Format header and footer
/// requirements
///
/// If a buffer is provided, then the reader needs to be a io::Cursor wrapped around the buffer
fn try_new_inter(mut reader: R, projection: Option<Vec<usize>>) -> Result<Self, ArrowError> {
fn try_new_inter(
reader: R,
projection: Option<Vec<usize>>,
buffer: Option<BufferReader>,
) -> Result<Self, ArrowError> {
let mut reader = BufReader::new(reader);
// check if header and footer contain correct magic bytes
// check if header and footer contain correct magic bytes
let mut magic_buffer: [u8; 6] = [0; 6];
Expand Down Expand Up @@ -723,12 +663,24 @@ impl<R: Read + Seek + BufferRead> FileReader<R> {

match message.header_type() {
crate::MessageHeader::DictionaryBatch => {
// read the block that makes up the dictionary batch into a buffer
let batch = message.header_as_dictionary_batch().unwrap();
reader.seek(SeekFrom::Start(
block.offset() as u64 + block.metaDataLength() as u64,
))?;
let buf = reader.read_buffer(message.bodyLength() as usize)?;

let buf = if let Some(buffer) = &buffer {
// Don't need to check compression as dictionaries aren't compressed.
buffer.inter.slice_with_length(
block.offset() as usize + block.metaDataLength() as usize,
message.bodyLength() as usize,
)
} else {
// read the block that makes up the dictionary batch into a buffer
let mut buf =
MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
reader.seek(SeekFrom::Start(
block.offset() as u64 + block.metaDataLength() as u64,
))?;
reader.read_exact(&mut buf)?;
buf.into()
};

read_dictionary(
&buf,
Expand Down Expand Up @@ -756,6 +708,7 @@ impl<R: Read + Seek + BufferRead> FileReader<R> {

Ok(Self {
reader,
buffer,
schema: Arc::new(schema),
blocks: blocks.iter().copied().collect(),
current_block: 0,
Expand Down Expand Up @@ -835,11 +788,21 @@ impl<R: Read + Seek + BufferRead> FileReader<R> {
ArrowError::IpcError("Unable to read IPC message as record batch".to_string())
})?;

// read the block that makes up the dictionary batch into a buffer
self.reader.seek(SeekFrom::Start(
block.offset() as u64 + block.metaDataLength() as u64,
))?;
let buf = self.reader.read_buffer(message.bodyLength() as usize)?;
let buf = if let Some(buffer) = &self.buffer {
// Don't need to check compression as if it is compressed it will be read into memory.
buffer.inter.slice_with_length(
block.offset() as usize + block.metaDataLength() as usize,
message.bodyLength() as usize,
)
} else {
// read the block that makes up the batch into a buffer
let mut buf = MutableBuffer::from_len_zeroed(message.bodyLength() as usize);
self.reader.seek(SeekFrom::Start(
block.offset() as u64 + block.metaDataLength() as u64,
))?;
self.reader.read_exact(&mut buf)?;
buf.into()
};

read_record_batch(
&buf,
Expand All @@ -857,9 +820,23 @@ impl<R: Read + Seek + BufferRead> FileReader<R> {
))),
}
}

/// Gets a reference to the underlying reader.
///
/// It is inadvisable to directly read from the underlying reader.
pub fn get_ref(&self) -> &R {
self.reader.get_ref()
}

/// Gets a mutable reference to the underlying reader.
///
/// It is inadvisable to directly read from the underlying reader.
pub fn get_mut(&mut self) -> &mut R {
self.reader.get_mut()
}
}

impl<R: Read + Seek + BufferRead> Iterator for FileReader<R> {
impl<R: Read + Seek> Iterator for FileReader<R> {
type Item = Result<RecordBatch, ArrowError>;

fn next(&mut self) -> Option<Self::Item> {
Expand All @@ -872,7 +849,7 @@ impl<R: Read + Seek + BufferRead> Iterator for FileReader<R> {
}
}

impl<R: Read + Seek + BufferRead> RecordBatchReader for FileReader<R> {
impl<R: Read + Seek> RecordBatchReader for FileReader<R> {
fn schema(&self) -> SchemaRef {
self.schema.clone()
}
Expand Down

0 comments on commit cbb0c52

Please sign in to comment.