Skip to content

Commit

Permalink
Add truncated rows to RecordDecoder::new
Browse files Browse the repository at this point in the history
Instead of using a setter, truncated rows is passwed into
the `new` method for RecordDecoder since it is not part
of the public API.
  • Loading branch information
Posnet committed May 7, 2024
1 parent e58a24d commit 2aa0bc9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 18 deletions.
11 changes: 6 additions & 5 deletions arrow-csv/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1164,9 +1164,11 @@ impl ReaderBuilder {
/// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
pub fn build_decoder(self) -> Decoder {
let delimiter = self.format.build_parser();
let mut record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len());

record_decoder.set_truncated_rows(self.format.truncated_rows);
let record_decoder = RecordDecoder::new(
delimiter,
self.schema.fields().len(),
self.format.truncated_rows,
);

let header = self.format.header as usize;

Expand Down Expand Up @@ -2207,11 +2209,10 @@ mod tests {

let batches = reader.collect::<Result<Vec<_>, _>>();
assert!(batches.is_ok());
let batch = batches.unwrap().into_iter().nth(0).unwrap();
let batch = batches.unwrap().into_iter().next().unwrap();
// Empty rows are skipped by the underlying csv parser
assert_eq!(batch.num_rows(), 3);


let reader = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_truncated_rows(false)
Expand Down
20 changes: 7 additions & 13 deletions arrow-csv/src/reader/records.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub struct RecordDecoder {
}

impl RecordDecoder {
pub fn new(delimiter: Reader, num_columns: usize) -> Self {
pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self {
Self {
delimiter,
num_columns,
Expand All @@ -76,7 +76,7 @@ impl RecordDecoder {
data_len: 0,
data: vec![],
num_rows: 0,
truncated_rows: false,
truncated_rows,
}
}

Expand Down Expand Up @@ -188,11 +188,6 @@ impl RecordDecoder {
self.num_rows = 0;
}

/// Sets the decoder to allow rows with less than the expected number columns
pub fn set_truncated_rows(&mut self, allow: bool) {
self.truncated_rows = allow;
}

/// Flushes the current contents of the reader
pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> {
if self.current_field != 0 {
Expand Down Expand Up @@ -320,7 +315,7 @@ mod tests {
.into_iter();

let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes()));
let mut decoder = RecordDecoder::new(Reader::new(), 3);
let mut decoder = RecordDecoder::new(Reader::new(), 3, false);

loop {
let to_read = 3;
Expand Down Expand Up @@ -354,15 +349,15 @@ mod tests {
#[test]
fn test_invalid_fields() {
let csv = "a,b\nb,c\na\n";
let mut decoder = RecordDecoder::new(Reader::new(), 2);
let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string();

let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1";

assert_eq!(err, expected);

// Test with initial skip
let mut decoder = RecordDecoder::new(Reader::new(), 2);
let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap();
assert_eq!(skipped, 1);
decoder.clear();
Expand All @@ -375,7 +370,7 @@ mod tests {
#[test]
fn test_skip_insufficient_rows() {
let csv = "a\nv\n";
let mut decoder = RecordDecoder::new(Reader::new(), 1);
let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap();
assert_eq!(read, 2);
assert_eq!(bytes, csv.len());
Expand All @@ -384,8 +379,7 @@ mod tests {
#[test]
fn test_truncated_rows() {
let csv = "a,b\nv\n,1\n,2\n,3\n";
let mut decoder = RecordDecoder::new(Reader::new(), 2);
decoder.set_truncated_rows(true);
let mut decoder = RecordDecoder::new(Reader::new(), 2, true);
let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap();
assert_eq!(read, 5);
assert_eq!(bytes, csv.len());
Expand Down

0 comments on commit 2aa0bc9

Please sign in to comment.