Skip to content

Commit

Permalink
Add support for truncated rows
Browse files Browse the repository at this point in the history
Similar to what is supported in the csv crate, as well as the pandas,
arrow-cpp and polars crates. A subset of CSV files treat missing columns
at the end of rows as null (if the schema allows it). This commit adds
support to optionally enable treating such missing columns as null.
The default behavior is still to treat an incorrect number of columns
as an error.
  • Loading branch information
Posnet committed May 2, 2024
1 parent c5b3304 commit d807a4e
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 5 deletions.
152 changes: 151 additions & 1 deletion arrow-csv/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ pub struct Format {
quote: Option<u8>,
terminator: Option<u8>,
null_regex: NullRegex,
allow_truncated_rows: bool,
}

impl Format {
Expand Down Expand Up @@ -265,6 +266,17 @@ impl Format {
self
}

/// Whether to allow truncated rows when parsing.
///
/// By default this is set to `false` and will error if the CSV rows have different lengths.
/// When set to true then it will allow records with less than the expected number of columns
/// and fill the missing columns with nulls. If the record's schema is not nullable, then it
/// will still return an error.
pub fn with_allow_truncated_rows(mut self, allow: bool) -> Self {
self.allow_truncated_rows = allow;
self
}

/// Infer schema of CSV records from the provided `reader`
///
/// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
Expand Down Expand Up @@ -329,6 +341,7 @@ impl Format {
fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
let mut builder = csv::ReaderBuilder::new();
builder.has_headers(self.header);
builder.flexible(self.allow_truncated_rows);

if let Some(c) = self.delimiter {
builder.delimiter(c);
Expand Down Expand Up @@ -1117,6 +1130,17 @@ impl ReaderBuilder {
self
}

/// Whether to allow truncated rows when parsing.
///
/// By default this is set to `false` and will error if the CSV rows have different lengths.
/// When set to true then it will allow records with less than the expected number of columns
/// and fill the missing columns with nulls. If the record's schema is not nullable, then it
/// will still return an error.
pub fn with_allow_truncated_rows(mut self, allow: bool) -> Self {
self.format.allow_truncated_rows = allow;
self
}

/// Create a new `Reader` from a non-buffered reader
///
/// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional
Expand All @@ -1136,7 +1160,9 @@ impl ReaderBuilder {
/// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
pub fn build_decoder(self) -> Decoder {
let delimiter = self.format.build_parser();
let record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len());
let mut record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len());

record_decoder.set_allow_truncated_rows(self.format.allow_truncated_rows);

let header = self.format.header as usize;

Expand Down Expand Up @@ -2160,6 +2186,130 @@ mod tests {
assert!(c.is_null(3));
}

#[test]
fn test_allow_truncated_rows() {
let data = "a,b,c\n1,2,3\n4,5";
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
Field::new("c", DataType::Int32, true),
]));

let reader = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_allow_truncated_rows(true)
.build(Cursor::new(data))
.unwrap();

let batches = reader.collect::<Result<Vec<_>, _>>();
assert!(batches.is_ok());

let reader = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_allow_truncated_rows(false)
.build(Cursor::new(data))
.unwrap();

let batches = reader.collect::<Result<Vec<_>, _>>();
assert!(match batches {
Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
_ => false,
});
}

#[test]
fn test_allow_truncated_rows_csv() {
let file = File::open("test/data/truncated_rows.csv").unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("Name", DataType::Utf8, true),
Field::new("Age", DataType::UInt32, true),
Field::new("Occupation", DataType::Utf8, true),
Field::new("DOB", DataType::Date32, true),
]));
let reader = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_batch_size(24)
.with_allow_truncated_rows(true);
let csv = reader.build(file).unwrap();
let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

assert_eq!(batches.len(), 1);
let batch = &batches[0];
assert_eq!(batch.num_rows(), 6);
assert_eq!(batch.num_columns(), 4);
let name = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let age = batch
.column(1)
.as_any()
.downcast_ref::<UInt32Array>()
.unwrap();
let occupation = batch
.column(2)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let dob = batch
.column(3)
.as_any()
.downcast_ref::<Date32Array>()
.unwrap();

assert_eq!(name.value(0), "A1");
assert_eq!(name.value(1), "B2");
assert!(name.is_null(2));
assert_eq!(name.value(3), "C3");
assert_eq!(name.value(4), "D4");
assert_eq!(name.value(5), "E5");

assert_eq!(age.value(0), 34);
assert_eq!(age.value(1), 29);
assert!(age.is_null(2));
assert_eq!(age.value(3), 45);
assert!(age.is_null(4));
assert_eq!(age.value(5), 31);

assert_eq!(occupation.value(0), "Engineer");
assert_eq!(occupation.value(1), "Doctor");
assert!(occupation.is_null(2));
assert_eq!(occupation.value(3), "Artist");
assert!(occupation.is_null(4));
assert!(occupation.is_null(5));

assert_eq!(dob.value(0), 5675);
assert!(dob.is_null(1));
assert!(dob.is_null(2));
assert_eq!(dob.value(3), -1858);
assert!(dob.is_null(4));
assert!(dob.is_null(5));
}

#[test]
fn test_allow_truncated_rows_not_nullable_error() {
let data = "a,b,c\n1,2,3\n4,5";
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]));

let reader = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_allow_truncated_rows(true)
.build(Cursor::new(data))
.unwrap();

let batches = reader.collect::<Result<Vec<_>, _>>();
assert!(match batches {
Err(ArrowError::InvalidArgumentError(e)) =>
e.to_string().contains("contains null values"),
_ => false,
});
}

#[test]
fn test_buffered() {
let tests = [
Expand Down
39 changes: 35 additions & 4 deletions arrow-csv/src/reader/records.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ pub struct RecordDecoder {
///
/// We track this independently of Vec to avoid re-zeroing memory
data_len: usize,

/// Whether rows with less than expected columns are considered valid
///
/// Default value is false
/// When enabled fills in missing columns with null
allow_truncated_rows: bool,
}

impl RecordDecoder {
Expand All @@ -70,6 +76,7 @@ impl RecordDecoder {
data_len: 0,
data: vec![],
num_rows: 0,
allow_truncated_rows: false,
}
}

Expand Down Expand Up @@ -127,10 +134,19 @@ impl RecordDecoder {
}
ReadRecordResult::Record => {
if self.current_field != self.num_columns {
return Err(ArrowError::CsvError(format!(
"incorrect number of fields for line {}, expected {} got {}",
self.line_number, self.num_columns, self.current_field
)));
if self.allow_truncated_rows && self.current_field < self.num_columns {
// If the number of fields is less than expected, pad with nulls
let fill_count = self.num_columns - self.current_field;
let fill_value = self.offsets[self.offsets_len - 1];
self.offsets[self.offsets_len..self.offsets_len + fill_count]
.fill(fill_value);
self.offsets_len += fill_count;
} else {
return Err(ArrowError::CsvError(format!(
"incorrect number of fields for line {}, expected {} got {}",
self.line_number, self.num_columns, self.current_field
)));
}
}
read += 1;
self.current_field = 0;
Expand Down Expand Up @@ -172,6 +188,11 @@ impl RecordDecoder {
self.num_rows = 0;
}

/// Sets the decoder to allow rows with less than the expected number columns
pub fn set_allow_truncated_rows(&mut self, allow: bool) {
self.allow_truncated_rows = allow;
}

/// Flushes the current contents of the reader
pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> {
if self.current_field != 0 {
Expand Down Expand Up @@ -359,4 +380,14 @@ mod tests {
assert_eq!(read, 2);
assert_eq!(bytes, csv.len());
}

#[test]
fn test_truncated_rows() {
let csv = "a,b\nv\n,1\n,2\n,3\n";
let mut decoder = RecordDecoder::new(Reader::new(), 2);
decoder.set_allow_truncated_rows(true);
let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap();
assert_eq!(read, 5);
assert_eq!(bytes, csv.len());
}
}
7 changes: 7 additions & 0 deletions arrow-csv/test/data/truncated_rows.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Name,Age,Occupation,DOB
A1,34,Engineer,1985-07-16
B2,29,Doctor
,,
C3,45,Artist,1964-11-30
D4
E5,31,,

0 comments on commit d807a4e

Please sign in to comment.