csv: Add support for flexible column lengths (#5679)
* Add support for truncated rows

Similar to what is supported by the csv crate, as well as by pandas,
arrow-cpp, and polars. Some CSV files treat columns missing at the end
of a row as null (if the schema allows it). This commit adds an option
to treat such missing columns as null. The default behavior is
unchanged: an incorrect number of columns is still an error.

* Add truncated rows to `RecordDecoder::new`

Instead of using a setter, the truncated-rows flag is passed into
`RecordDecoder::new`, since `RecordDecoder` is not part
of the public API.
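
A minimal usage sketch of the new option (illustrative only, not part of this
commit; the schema and data here are assumptions):

use std::io::Cursor;
use std::sync::Arc;
use arrow_array::Array;
use arrow_csv::ReaderBuilder;
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));
    // The second row is truncated: column `b` is missing and is filled with null
    let data = "1,2\n3\n";
    let mut reader = ReaderBuilder::new(schema)
        .with_truncated_rows(true)
        .build(Cursor::new(data))
        .unwrap();
    let batch = reader.next().unwrap().unwrap();
    assert_eq!(batch.num_rows(), 2);
    assert!(batch.column(1).is_null(1));
}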
Posnet authored May 8, 2024
1 parent f67a5ce commit 4045fb5
Showing 3 changed files with 198 additions and 10 deletions.
157 changes: 156 additions & 1 deletion arrow-csv/src/reader/mod.rs
@@ -231,6 +231,7 @@ pub struct Format {
quote: Option<u8>,
terminator: Option<u8>,
null_regex: NullRegex,
truncated_rows: bool,
}

impl Format {
@@ -265,6 +266,17 @@ impl Format {
self
}

/// Whether to allow truncated rows when parsing.
///
/// By default this is `false`, and the reader errors if the CSV rows have differing lengths.
/// When set to `true`, records with fewer than the expected number of columns are accepted,
/// and the missing columns are filled with nulls. If a missing column is not nullable in the
/// schema, an error is still returned.
pub fn with_truncated_rows(mut self, allow: bool) -> Self {
self.truncated_rows = allow;
self
}

/// Infer schema of CSV records from the provided `reader`
///
/// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
@@ -329,6 +341,7 @@ impl Format {
fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
let mut builder = csv::ReaderBuilder::new();
builder.has_headers(self.header);
builder.flexible(self.truncated_rows);

if let Some(c) = self.delimiter {
builder.delimiter(c);
@@ -1121,6 +1134,17 @@ impl ReaderBuilder {
self
}

/// Whether to allow truncated rows when parsing.
///
/// By default this is `false`, and the reader errors if the CSV rows have differing lengths.
/// When set to `true`, records with fewer than the expected number of columns are accepted,
/// and the missing columns are filled with nulls. If a missing column is not nullable in the
/// schema, an error is still returned.
pub fn with_truncated_rows(mut self, allow: bool) -> Self {
self.format.truncated_rows = allow;
self
}

/// Create a new `Reader` from a non-buffered reader
///
/// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional
@@ -1140,7 +1164,11 @@ impl ReaderBuilder {
/// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
pub fn build_decoder(self) -> Decoder {
let delimiter = self.format.build_parser();
let record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len());
let record_decoder = RecordDecoder::new(
delimiter,
self.schema.fields().len(),
self.format.truncated_rows,
);

let header = self.format.header as usize;

@@ -2164,6 +2192,133 @@ mod tests {
assert!(c.is_null(3));
}

#[test]
fn test_truncated_rows() {
let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, true),
Field::new("b", DataType::Int32, true),
Field::new("c", DataType::Int32, true),
]));

let reader = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_truncated_rows(true)
.build(Cursor::new(data))
.unwrap();

let batches = reader.collect::<Result<Vec<_>, _>>();
assert!(batches.is_ok());
let batch = batches.unwrap().into_iter().next().unwrap();
// Empty rows are skipped by the underlying csv parser
assert_eq!(batch.num_rows(), 3);

let reader = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_truncated_rows(false)
.build(Cursor::new(data))
.unwrap();

let batches = reader.collect::<Result<Vec<_>, _>>();
assert!(match batches {
Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
_ => false,
});
}

#[test]
fn test_truncated_rows_csv() {
let file = File::open("test/data/truncated_rows.csv").unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("Name", DataType::Utf8, true),
Field::new("Age", DataType::UInt32, true),
Field::new("Occupation", DataType::Utf8, true),
Field::new("DOB", DataType::Date32, true),
]));
let reader = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_batch_size(24)
.with_truncated_rows(true);
let csv = reader.build(file).unwrap();
let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();

assert_eq!(batches.len(), 1);
let batch = &batches[0];
assert_eq!(batch.num_rows(), 6);
assert_eq!(batch.num_columns(), 4);
let name = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let age = batch
.column(1)
.as_any()
.downcast_ref::<UInt32Array>()
.unwrap();
let occupation = batch
.column(2)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let dob = batch
.column(3)
.as_any()
.downcast_ref::<Date32Array>()
.unwrap();

assert_eq!(name.value(0), "A1");
assert_eq!(name.value(1), "B2");
assert!(name.is_null(2));
assert_eq!(name.value(3), "C3");
assert_eq!(name.value(4), "D4");
assert_eq!(name.value(5), "E5");

assert_eq!(age.value(0), 34);
assert_eq!(age.value(1), 29);
assert!(age.is_null(2));
assert_eq!(age.value(3), 45);
assert!(age.is_null(4));
assert_eq!(age.value(5), 31);

assert_eq!(occupation.value(0), "Engineer");
assert_eq!(occupation.value(1), "Doctor");
assert!(occupation.is_null(2));
assert_eq!(occupation.value(3), "Artist");
assert!(occupation.is_null(4));
assert!(occupation.is_null(5));

assert_eq!(dob.value(0), 5675);
assert!(dob.is_null(1));
assert!(dob.is_null(2));
assert_eq!(dob.value(3), -1858);
assert!(dob.is_null(4));
assert!(dob.is_null(5));
}

#[test]
fn test_truncated_rows_not_nullable_error() {
let data = "a,b,c\n1,2,3\n4,5";
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
Field::new("c", DataType::Int32, false),
]));

let reader = ReaderBuilder::new(schema.clone())
.with_header(true)
.with_truncated_rows(true)
.build(Cursor::new(data))
.unwrap();

let batches = reader.collect::<Result<Vec<_>, _>>();
assert!(match batches {
Err(ArrowError::InvalidArgumentError(e)) =>
e.to_string().contains("contains null values"),
_ => false,
});
}

#[test]
fn test_buffered() {
let tests = [
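For reference, a sketch of combining the new Format option with schema
inference (not part of this commit; assumes arrow-csv's Format::default,
infer_schema, and ReaderBuilder::with_format APIs, and reuses the test
fixture added below):

use std::fs::File;
use std::io::Seek;
use std::sync::Arc;
use arrow_csv::reader::Format;
use arrow_csv::ReaderBuilder;

fn main() {
    let mut file = File::open("test/data/truncated_rows.csv").unwrap();
    let format = Format::default()
        .with_header(true)
        .with_truncated_rows(true);
    // Inference also tolerates short rows, since build_reader now sets flexible()
    let (schema, _n_records) = format.infer_schema(&mut file, None).unwrap();
    file.rewind().unwrap();
    let reader = ReaderBuilder::new(Arc::new(schema))
        .with_format(format)
        .build(file)
        .unwrap();
    for batch in reader {
        println!("{} rows", batch.unwrap().num_rows());
    }
}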
43 changes: 34 additions & 9 deletions arrow-csv/src/reader/records.rs
@@ -56,10 +56,16 @@ pub struct RecordDecoder {
///
/// We track this independently of Vec to avoid re-zeroing memory
data_len: usize,

/// Whether rows with fewer than the expected number of columns are considered valid
///
/// Default value is `false`
/// When enabled, missing columns are filled with null
truncated_rows: bool,
}

impl RecordDecoder {
pub fn new(delimiter: Reader, num_columns: usize) -> Self {
pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self {
Self {
delimiter,
num_columns,
@@ -70,6 +76,7 @@ impl RecordDecoder {
data_len: 0,
data: vec![],
num_rows: 0,
truncated_rows,
}
}

@@ -127,10 +134,19 @@ impl RecordDecoder {
}
ReadRecordResult::Record => {
if self.current_field != self.num_columns {
return Err(ArrowError::CsvError(format!(
"incorrect number of fields for line {}, expected {} got {}",
self.line_number, self.num_columns, self.current_field
)));
if self.truncated_rows && self.current_field < self.num_columns {
// If the number of fields is less than expected, pad with nulls
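// Repeating the last offset yields zero-length fields; downstream
// column builders decode these empty values as nulls (for nullable fields)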
let fill_count = self.num_columns - self.current_field;
let fill_value = self.offsets[self.offsets_len - 1];
self.offsets[self.offsets_len..self.offsets_len + fill_count]
.fill(fill_value);
self.offsets_len += fill_count;
} else {
return Err(ArrowError::CsvError(format!(
"incorrect number of fields for line {}, expected {} got {}",
self.line_number, self.num_columns, self.current_field
)));
}
}
read += 1;
self.current_field = 0;
@@ -299,7 +315,7 @@ mod tests {
.into_iter();

let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes()));
let mut decoder = RecordDecoder::new(Reader::new(), 3);
let mut decoder = RecordDecoder::new(Reader::new(), 3, false);

loop {
let to_read = 3;
@@ -333,15 +349,15 @@
#[test]
fn test_invalid_fields() {
let csv = "a,b\nb,c\na\n";
let mut decoder = RecordDecoder::new(Reader::new(), 2);
let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string();

let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1";

assert_eq!(err, expected);

// Test with initial skip
let mut decoder = RecordDecoder::new(Reader::new(), 2);
let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap();
assert_eq!(skipped, 1);
decoder.clear();
Expand All @@ -354,9 +370,18 @@ mod tests {
#[test]
fn test_skip_insufficient_rows() {
let csv = "a\nv\n";
let mut decoder = RecordDecoder::new(Reader::new(), 1);
let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap();
assert_eq!(read, 2);
assert_eq!(bytes, csv.len());
}

#[test]
fn test_truncated_rows() {
let csv = "a,b\nv\n,1\n,2\n,3\n";
let mut decoder = RecordDecoder::new(Reader::new(), 2, true);
let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap();
assert_eq!(read, 5);
assert_eq!(bytes, csv.len());
}
}
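
Because RecordDecoder is crate-internal, streaming callers reach the new flag
through the public Decoder; a sketch (not part of this commit; assumes
Decoder's decode/flush API and illustrative data):

use std::sync::Arc;
use arrow_csv::ReaderBuilder;
use arrow_schema::{DataType, Field, Schema};

fn main() {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int32, true),
        Field::new("b", DataType::Int32, true),
    ]));
    let mut decoder = ReaderBuilder::new(schema)
        .with_truncated_rows(true)
        .build_decoder();
    // Feed bytes as they arrive; decode returns the number of bytes consumed
    decoder.decode(b"1,2\n3\n").unwrap();
    // flush yields a RecordBatch of the complete rows decoded so far
    let batch = decoder.flush().unwrap().unwrap();
    assert_eq!(batch.num_rows(), 2);
}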
8 changes: 8 additions & 0 deletions arrow-csv/test/data/truncated_rows.csv
@@ -0,0 +1,8 @@
Name,Age,Occupation,DOB
A1,34,Engineer,1985-07-16
B2,29,Doctor
,
C3,45,Artist,1964-11-30

D4
E5,31,,
