From 4045fb5ba2dae222aab7365250d788ae8872e229 Mon Sep 17 00:00:00 2001
From: Posnet
Date: Wed, 8 May 2024 19:47:57 +1000
Subject: [PATCH] csv: Add support for flexible column lengths (#5679)

* Add support for truncated rows

Similar to what is supported by the csv crate, as well as by pandas,
arrow-cpp and polars. A subset of CSV files treat missing columns at
the end of rows as null (if the schema allows it). This commit adds
support for optionally treating such missing columns as null. The
default behavior is still to treat an incorrect number of columns as
an error.

* Add truncated rows to `RecordDecoder::new`

Instead of using a setter, the truncated-rows flag is passed into the
`new` method of RecordDecoder, since it is not part of the public API.
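For illustration, a minimal end-to-end sketch of the new option. This
mirrors the `test_truncated_rows` test added below; the imports assume
the `arrow_csv` and `arrow_schema` crates from this repository:

    use std::io::Cursor;
    use std::sync::Arc;

    use arrow_csv::ReaderBuilder;
    use arrow_schema::{DataType, Field, Schema};

    fn main() {
        // Row "4,5" is missing column `c`; with truncated rows enabled
        // it decodes as a null rather than failing the whole read.
        let data = "a,b,c\n1,2,3\n4,5\n";
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
            Field::new("c", DataType::Int32, true),
        ]));
        let reader = ReaderBuilder::new(schema)
            .with_header(true)
            .with_truncated_rows(true)
            .build(Cursor::new(data))
            .unwrap();
        for batch in reader {
            let batch = batch.unwrap();
            assert_eq!(batch.num_rows(), 2);
        }
    }

With `with_truncated_rows(false)` (the default) the same input fails
with a CsvError, preserving the existing behavior.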
---
 arrow-csv/src/reader/mod.rs            | 157 ++++++++++++++++++++++++-
 arrow-csv/src/reader/records.rs        |  43 +++++--
 arrow-csv/test/data/truncated_rows.csv |   8 ++
 3 files changed, 198 insertions(+), 10 deletions(-)
 create mode 100644 arrow-csv/test/data/truncated_rows.csv

diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index 4d7d60e10cf6..09087ca31958 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -231,6 +231,7 @@ pub struct Format {
     quote: Option<u8>,
     terminator: Option<u8>,
     null_regex: NullRegex,
+    truncated_rows: bool,
 }
 
 impl Format {
@@ -265,6 +266,17 @@ impl Format {
         self
     }
 
+    /// Whether to allow truncated rows when parsing.
+    ///
+    /// By default this is set to `false` and will error if the CSV rows have different lengths.
+    /// When set to `true` it will allow records with fewer than the expected number of columns
+    /// and fill the missing columns with nulls. If the record's schema is not nullable, then it
+    /// will still return an error.
+    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
+        self.truncated_rows = allow;
+        self
+    }
+
     /// Infer schema of CSV records from the provided `reader`
     ///
     /// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
@@ -329,6 +341,7 @@ impl Format {
     fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
         let mut builder = csv::ReaderBuilder::new();
         builder.has_headers(self.header);
+        builder.flexible(self.truncated_rows);
 
         if let Some(c) = self.delimiter {
             builder.delimiter(c);
@@ -1121,6 +1134,17 @@ impl ReaderBuilder {
         self
     }
 
+    /// Whether to allow truncated rows when parsing.
+    ///
+    /// By default this is set to `false` and will error if the CSV rows have different lengths.
+    /// When set to `true` it will allow records with fewer than the expected number of columns
+    /// and fill the missing columns with nulls. If the record's schema is not nullable, then it
+    /// will still return an error.
+    pub fn with_truncated_rows(mut self, allow: bool) -> Self {
+        self.format.truncated_rows = allow;
+        self
+    }
+
     /// Create a new `Reader` from a non-buffered reader
     ///
     /// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional
@@ -1140,7 +1164,11 @@ impl ReaderBuilder {
     /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
     pub fn build_decoder(self) -> Decoder {
         let delimiter = self.format.build_parser();
-        let record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len());
+        let record_decoder = RecordDecoder::new(
+            delimiter,
+            self.schema.fields().len(),
+            self.format.truncated_rows,
+        );
 
         let header = self.format.header as usize;
 
@@ -2164,6 +2192,133 @@ mod tests {
         assert!(c.is_null(3));
     }
 
+    #[test]
+    fn test_truncated_rows() {
+        let data = "a,b,c\n1,2,3\n4,5\n\n6,7,8";
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+            Field::new("c", DataType::Int32, true),
+        ]));
+
+        let reader = ReaderBuilder::new(schema.clone())
+            .with_header(true)
+            .with_truncated_rows(true)
+            .build(Cursor::new(data))
+            .unwrap();
+
+        let batches = reader.collect::<Result<Vec<_>, _>>();
+        assert!(batches.is_ok());
+        let batch = batches.unwrap().into_iter().next().unwrap();
+        // Empty rows are skipped by the underlying csv parser
+        assert_eq!(batch.num_rows(), 3);
+
+        let reader = ReaderBuilder::new(schema.clone())
+            .with_header(true)
+            .with_truncated_rows(false)
+            .build(Cursor::new(data))
+            .unwrap();
+
+        let batches = reader.collect::<Result<Vec<_>, _>>();
+        assert!(match batches {
+            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
+            _ => false,
+        });
+    }
+
+    #[test]
+    fn test_truncated_rows_csv() {
+        let file = File::open("test/data/truncated_rows.csv").unwrap();
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("Name", DataType::Utf8, true),
+            Field::new("Age", DataType::UInt32, true),
+            Field::new("Occupation", DataType::Utf8, true),
+            Field::new("DOB", DataType::Date32, true),
+        ]));
+        let reader = ReaderBuilder::new(schema.clone())
+            .with_header(true)
+            .with_batch_size(24)
+            .with_truncated_rows(true);
+        let csv = reader.build(file).unwrap();
+        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();
+
+        assert_eq!(batches.len(), 1);
+        let batch = &batches[0];
+        assert_eq!(batch.num_rows(), 6);
+        assert_eq!(batch.num_columns(), 4);
+        let name = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        let age = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<UInt32Array>()
+            .unwrap();
+        let occupation = batch
+            .column(2)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        let dob = batch
+            .column(3)
+            .as_any()
+            .downcast_ref::<Date32Array>()
+            .unwrap();
+
+        assert_eq!(name.value(0), "A1");
+        assert_eq!(name.value(1), "B2");
+        assert!(name.is_null(2));
+        assert_eq!(name.value(3), "C3");
+        assert_eq!(name.value(4), "D4");
+        assert_eq!(name.value(5), "E5");
+
+        assert_eq!(age.value(0), 34);
+        assert_eq!(age.value(1), 29);
+        assert!(age.is_null(2));
+        assert_eq!(age.value(3), 45);
+        assert!(age.is_null(4));
+        assert_eq!(age.value(5), 31);
+
+        assert_eq!(occupation.value(0), "Engineer");
+        assert_eq!(occupation.value(1), "Doctor");
+        assert!(occupation.is_null(2));
+        assert_eq!(occupation.value(3), "Artist");
+        assert!(occupation.is_null(4));
+        assert!(occupation.is_null(5));
+
+        assert_eq!(dob.value(0), 5675);
+        assert!(dob.is_null(1));
+        assert!(dob.is_null(2));
+        assert_eq!(dob.value(3), -1858);
+        assert!(dob.is_null(4));
+        assert!(dob.is_null(5));
+    }
+
+    #[test]
+    fn test_truncated_rows_not_nullable_error() {
+        let data = "a,b,c\n1,2,3\n4,5";
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]));
+
+        let reader = ReaderBuilder::new(schema.clone())
+            .with_header(true)
+            .with_truncated_rows(true)
+            .build(Cursor::new(data))
+            .unwrap();
+
+        let batches = reader.collect::<Result<Vec<_>, _>>();
+        assert!(match batches {
+            Err(ArrowError::InvalidArgumentError(e)) =>
+                e.to_string().contains("contains null values"),
+            _ => false,
+        });
+    }
+
     #[test]
     fn test_buffered() {
         let tests = [
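The `build_decoder` hunk above threads the same flag into the
push-based `Decoder`. For context, a sketch of how that surfaces to a
streaming caller; the `decode`/`flush` calls follow the existing
`Decoder` API, and the chunk boundaries below are arbitrary, chosen
only to show that rows may be split across feeds:

    use std::sync::Arc;

    use arrow_array::Array;
    use arrow_csv::ReaderBuilder;
    use arrow_schema::{DataType, Field, Schema};

    fn main() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("a", DataType::Int32, true),
            Field::new("b", DataType::Int32, true),
        ]));
        let mut decoder = ReaderBuilder::new(schema)
            .with_truncated_rows(true)
            .build_decoder();

        // Feed bytes as they might arrive from a socket; the short
        // row "3" is padded with a null for column `b`.
        for chunk in ["1,2\n3", "\n4,5\n"] {
            let mut buf = chunk.as_bytes();
            while !buf.is_empty() {
                let read = decoder.decode(buf).unwrap();
                buf = &buf[read..];
            }
        }
        let batch = decoder.flush().unwrap().expect("rows buffered");
        assert_eq!(batch.num_rows(), 3);
        assert!(batch.column(1).is_null(1));
    }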
diff --git a/arrow-csv/src/reader/records.rs b/arrow-csv/src/reader/records.rs
index 877cfb3ee653..a07fc9c94ffa 100644
--- a/arrow-csv/src/reader/records.rs
+++ b/arrow-csv/src/reader/records.rs
@@ -56,10 +56,16 @@ pub struct RecordDecoder {
     ///
     /// We track this independently of Vec to avoid re-zeroing memory
     data_len: usize,
+
+    /// Whether rows with fewer than the expected number of columns are considered valid
+    ///
+    /// Default value is `false`
+    /// When enabled, fills in missing columns with null
+    truncated_rows: bool,
 }
 
 impl RecordDecoder {
-    pub fn new(delimiter: Reader, num_columns: usize) -> Self {
+    pub fn new(delimiter: Reader, num_columns: usize, truncated_rows: bool) -> Self {
         Self {
             delimiter,
             num_columns,
@@ -70,6 +76,7 @@ impl RecordDecoder {
             data_len: 0,
             data: vec![],
             num_rows: 0,
+            truncated_rows,
         }
     }
 
@@ -127,10 +134,19 @@ impl RecordDecoder {
                 }
                 ReadRecordResult::Record => {
                     if self.current_field != self.num_columns {
-                        return Err(ArrowError::CsvError(format!(
-                            "incorrect number of fields for line {}, expected {} got {}",
-                            self.line_number, self.num_columns, self.current_field
-                        )));
+                        if self.truncated_rows && self.current_field < self.num_columns {
+                            // If the row has fewer fields than expected, pad with nulls
+                            let fill_count = self.num_columns - self.current_field;
+                            let fill_value = self.offsets[self.offsets_len - 1];
+                            self.offsets[self.offsets_len..self.offsets_len + fill_count]
+                                .fill(fill_value);
+                            self.offsets_len += fill_count;
+                        } else {
+                            return Err(ArrowError::CsvError(format!(
+                                "incorrect number of fields for line {}, expected {} got {}",
+                                self.line_number, self.num_columns, self.current_field
+                            )));
+                        }
                     }
                     read += 1;
                     self.current_field = 0;
@@ -299,7 +315,7 @@ mod tests {
             .into_iter();
 
         let mut reader = BufReader::with_capacity(3, Cursor::new(csv.as_bytes()));
-        let mut decoder = RecordDecoder::new(Reader::new(), 3);
+        let mut decoder = RecordDecoder::new(Reader::new(), 3, false);
 
         loop {
             let to_read = 3;
@@ -333,7 +349,7 @@ mod tests {
     #[test]
     fn test_invalid_fields() {
         let csv = "a,b\nb,c\na\n";
-        let mut decoder = RecordDecoder::new(Reader::new(), 2);
+        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
         let err = decoder.decode(csv.as_bytes(), 4).unwrap_err().to_string();
 
         let expected = "Csv error: incorrect number of fields for line 3, expected 2 got 1";
@@ -341,7 +357,7 @@ mod tests {
         assert_eq!(err, expected);
 
         // Test with initial skip
-        let mut decoder = RecordDecoder::new(Reader::new(), 2);
+        let mut decoder = RecordDecoder::new(Reader::new(), 2, false);
         let (skipped, bytes) = decoder.decode(csv.as_bytes(), 1).unwrap();
         assert_eq!(skipped, 1);
         decoder.clear();
@@ -354,9 +370,18 @@ mod tests {
     #[test]
     fn test_skip_insufficient_rows() {
         let csv = "a\nv\n";
-        let mut decoder = RecordDecoder::new(Reader::new(), 1);
+        let mut decoder = RecordDecoder::new(Reader::new(), 1, false);
        let (read, bytes) = decoder.decode(csv.as_bytes(), 3).unwrap();
         assert_eq!(read, 2);
         assert_eq!(bytes, csv.len());
     }
+
+    #[test]
+    fn test_truncated_rows() {
+        let csv = "a,b\nv\n,1\n,2\n,3\n";
+        let mut decoder = RecordDecoder::new(Reader::new(), 2, true);
+        let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap();
+        assert_eq!(read, 5);
+        assert_eq!(bytes, csv.len());
+    }
 }
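The `Record` arm above is the heart of the change: `RecordDecoder`
stores one end offset per parsed field, so repeating the final offset
fabricates zero-length fields, which the downstream value parsers then
surface as nulls for nullable columns. A simplified, self-contained
model of that invariant; `pad_offsets` is a hypothetical helper, not
the crate's internal code:

    fn pad_offsets(offsets: &mut Vec<usize>, fields_read: usize, num_columns: usize) {
        // `offsets` holds one end boundary per field (plus the row
        // start); a truncated row has fewer boundaries than columns,
        // so repeat the final boundary to synthesize zero-length
        // fields for the missing columns.
        let last = *offsets.last().expect("at least the row start");
        offsets.extend(std::iter::repeat(last).take(num_columns - fields_read));
    }

    fn main() {
        // Row "a,b" parsed against 4 expected columns: the field data
        // is stored contiguously as "ab", with ends at 1 and 2.
        let mut offsets = vec![0, 1, 2];
        pad_offsets(&mut offsets, 2, 4);
        assert_eq!(offsets, vec![0, 1, 2, 2, 2]);
        // The zero-length slices [2..2] decode as nulls downstream,
        // as exercised by the tests above.
    }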
diff --git a/arrow-csv/test/data/truncated_rows.csv b/arrow-csv/test/data/truncated_rows.csv
new file mode 100644
index 000000000000..0b2af5740095
--- /dev/null
+++ b/arrow-csv/test/data/truncated_rows.csv
@@ -0,0 +1,8 @@
+Name,Age,Occupation,DOB
+A1,34,Engineer,1985-07-16
+B2,29,Doctor
+,
+C3,45,Artist,1964-11-30
+
+D4
+E5,31,,
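A closing note on the `Format` half of the change: because
`build_reader` now sets `flexible` on the underlying csv reader,
schema inference should also tolerate ragged input rather than
erroring on the first short row. A sketch under that assumption, using
`Format::infer_schema`:

    use std::io::Cursor;

    use arrow_csv::reader::Format;

    fn main() {
        let format = Format::default()
            .with_header(true)
            .with_truncated_rows(true);
        // The second data row is missing column `c`; with the flag
        // set, inference still sees three columns from the header.
        let mut data = Cursor::new("a,b,c\n1,2,3\n4,5\n");
        let (schema, read) = format.infer_schema(&mut data, None).unwrap();
        assert_eq!(read, 2);
        assert_eq!(schema.fields().len(), 3);
    }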