diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index 4d7d60e10cf6..542a8908d349 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -231,6 +231,7 @@ pub struct Format {
     quote: Option<u8>,
     terminator: Option<u8>,
     null_regex: NullRegex,
+    allow_truncated_rows: bool,
 }
 
 impl Format {
@@ -265,6 +266,17 @@ impl Format {
         self
     }
 
+    /// Whether to allow truncated rows when parsing.
+    ///
+    /// By default this is `false`, and a row with an unexpected number of columns is an error.
+    /// When set to `true`, rows with fewer columns than expected are accepted and the missing
+    /// columns are filled with nulls. If the schema declares a missing column as non-nullable,
+    /// an error is still returned.
+    pub fn with_allow_truncated_rows(mut self, allow: bool) -> Self {
+        self.allow_truncated_rows = allow;
+        self
+    }
+
     /// Infer schema of CSV records from the provided `reader`
     ///
     /// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
@@ -329,6 +341,7 @@ impl Format {
     fn build_reader<R: Read>(&self, reader: R) -> csv::Reader<R> {
         let mut builder = csv::ReaderBuilder::new();
         builder.has_headers(self.header);
+        builder.flexible(self.allow_truncated_rows);
 
         if let Some(c) = self.delimiter {
             builder.delimiter(c);
@@ -1121,6 +1134,17 @@ impl ReaderBuilder {
         self
     }
 
+    /// Whether to allow truncated rows when parsing.
+    ///
+    /// By default this is `false`, and a row with an unexpected number of columns is an error.
+    /// When set to `true`, rows with fewer columns than expected are accepted and the missing
+    /// columns are filled with nulls. If the schema declares a missing column as non-nullable,
+    /// an error is still returned.
+    pub fn with_allow_truncated_rows(mut self, allow: bool) -> Self {
+        self.format.allow_truncated_rows = allow;
+        self
+    }
+
     /// Create a new `Reader` from a non-buffered reader
     ///
     /// If `R: BufRead` consider using [`Self::build_buffered`] to avoid unnecessary additional
@@ -1140,7 +1164,9 @@ impl ReaderBuilder {
     /// Builds a decoder that can be used to decode CSV from an arbitrary byte stream
     pub fn build_decoder(self) -> Decoder {
         let delimiter = self.format.build_parser();
-        let record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len());
+        let mut record_decoder = RecordDecoder::new(delimiter, self.schema.fields().len());
+
+        record_decoder.set_allow_truncated_rows(self.format.allow_truncated_rows);
 
         let header = self.format.header as usize;
 
@@ -2164,6 +2190,133 @@ mod tests {
         assert!(c.is_null(3));
     }
 
+    #[test]
+    fn test_allow_truncated_rows() {
+        let data = "a,b,c\n1,2,3\n4,5";
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+            Field::new("c", DataType::Int32, true),
+        ]));
+
+        let reader = ReaderBuilder::new(schema.clone())
+            .with_header(true)
+            .with_allow_truncated_rows(true)
+            .build(Cursor::new(data))
+            .unwrap();
+
+        let batches = reader.collect::<Result<Vec<_>, _>>().unwrap();
+        assert_eq!(batches.len(), 1);
+        assert_eq!(batches[0].num_rows(), 2);
+        // The truncated row `4,5` must surface its missing third column as null
+        assert!(batches[0].column(2).is_null(1));
+
+        let reader = ReaderBuilder::new(schema.clone())
+            .with_header(true)
+            .with_allow_truncated_rows(false)
+            .build(Cursor::new(data))
+            .unwrap();
+
+        let batches = reader.collect::<Result<Vec<_>, _>>();
+        assert!(match batches {
+            Err(ArrowError::CsvError(e)) => e.to_string().contains("incorrect number of fields"),
+            _ => false,
+        });
+    }
+
+    #[test]
+    fn test_allow_truncated_rows_csv() {
+        let file = File::open("test/data/truncated_rows.csv").unwrap();
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("Name", DataType::Utf8, true),
+            Field::new("Age", DataType::UInt32, true),
+            Field::new("Occupation", DataType::Utf8, true),
+            Field::new("DOB", DataType::Date32, true),
+        ]));
+        let reader = ReaderBuilder::new(schema.clone())
+            .with_header(true)
+            .with_batch_size(24)
+            .with_allow_truncated_rows(true);
+        let csv = reader.build(file).unwrap();
+        let batches = csv.collect::<Result<Vec<_>, _>>().unwrap();
+
+        assert_eq!(batches.len(), 1);
+        let batch = &batches[0];
+        assert_eq!(batch.num_rows(), 6);
+        assert_eq!(batch.num_columns(), 4);
+        let name = batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        let age = batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<UInt32Array>()
+            .unwrap();
+        let occupation = batch
+            .column(2)
+            .as_any()
+            .downcast_ref::<StringArray>()
+            .unwrap();
+        let dob = batch
+            .column(3)
+            .as_any()
+            .downcast_ref::<Date32Array>()
+            .unwrap();
+
+        assert_eq!(name.value(0), "A1");
+        assert_eq!(name.value(1), "B2");
+        assert!(name.is_null(2));
+        assert_eq!(name.value(3), "C3");
+        assert_eq!(name.value(4), "D4");
+        assert_eq!(name.value(5), "E5");
+
+        assert_eq!(age.value(0), 34);
+        assert_eq!(age.value(1), 29);
+        assert!(age.is_null(2));
+        assert_eq!(age.value(3), 45);
+        assert!(age.is_null(4));
+        assert_eq!(age.value(5), 31);
+
+        assert_eq!(occupation.value(0), "Engineer");
+        assert_eq!(occupation.value(1), "Doctor");
+        assert!(occupation.is_null(2));
+        assert_eq!(occupation.value(3), "Artist");
+        assert!(occupation.is_null(4));
+        assert!(occupation.is_null(5));
+
+        assert_eq!(dob.value(0), 5675);
+        assert!(dob.is_null(1));
+        assert!(dob.is_null(2));
+        assert_eq!(dob.value(3), -1858);
+        assert!(dob.is_null(4));
+        assert!(dob.is_null(5));
+    }
+
+    #[test]
+    fn test_allow_truncated_rows_not_nullable_error() {
+        let data = "a,b,c\n1,2,3\n4,5";
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]));
+
+        let reader = ReaderBuilder::new(schema.clone())
+            .with_header(true)
+            .with_allow_truncated_rows(true)
+            .build(Cursor::new(data))
+            .unwrap();
+
+        let batches = reader.collect::<Result<Vec<_>, _>>();
+        assert!(match batches {
+            Err(ArrowError::InvalidArgumentError(e)) =>
+                e.to_string().contains("contains null values"),
+            _ => false,
+        });
+    }
+
     #[test]
     fn test_buffered() {
         let tests = [
diff --git a/arrow-csv/src/reader/records.rs b/arrow-csv/src/reader/records.rs
index 877cfb3ee653..63d6da7354a9 100644
--- a/arrow-csv/src/reader/records.rs
+++ b/arrow-csv/src/reader/records.rs
@@ -56,6 +56,12 @@ pub struct RecordDecoder {
     ///
     /// We track this independently of Vec to avoid re-zeroing memory
     data_len: usize,
+
+    /// Whether rows with fewer than the expected number of columns are considered valid
+    ///
+    /// Defaults to `false`
+    /// When enabled, the missing trailing columns are padded with empty fields
+    allow_truncated_rows: bool,
 }
 
 impl RecordDecoder {
@@ -70,6 +76,7 @@ impl RecordDecoder {
             data_len: 0,
             data: vec![],
             num_rows: 0,
+            allow_truncated_rows: false,
         }
     }
 
@@ -127,10 +134,19 @@ impl RecordDecoder {
                     }
                     ReadRecordResult::Record => {
                         if self.current_field != self.num_columns {
-                            return Err(ArrowError::CsvError(format!(
-                                "incorrect number of fields for line {}, expected {} got {}",
-                                self.line_number, self.num_columns, self.current_field
-                            )));
+                            if self.allow_truncated_rows && self.current_field < self.num_columns {
+                                // Too few fields: pad with empty fields at the last end offset
+                                let fill_count = self.num_columns - self.current_field;
+                                let fill_value = self.offsets[self.offsets_len - 1];
+                                self.offsets[self.offsets_len..self.offsets_len + fill_count]
+                                    .fill(fill_value);
+                                self.offsets_len += fill_count;
+                            } else {
+                                return Err(ArrowError::CsvError(format!(
+                                    "incorrect number of fields for line {}, expected {} got {}",
+                                    self.line_number, self.num_columns, self.current_field
+                                )));
+                            }
                         }
                         read += 1;
                         self.current_field = 0;
@@ -172,6 +188,11 @@ impl RecordDecoder {
         self.num_rows = 0;
     }
 
+    /// Sets whether the decoder allows rows with fewer than the expected number of columns
+    pub fn set_allow_truncated_rows(&mut self, allow: bool) {
+        self.allow_truncated_rows = allow;
+    }
+
     /// Flushes the current contents of the reader
     pub fn flush(&mut self) -> Result<StringRecords<'_>, ArrowError> {
         if self.current_field != 0 {
@@ -359,4 +380,19 @@ mod tests {
         assert_eq!(read, 2);
         assert_eq!(bytes, csv.len());
     }
+
+    #[test]
+    fn test_truncated_rows() {
+        let csv = "a,b\nv\n,1\n,2\n,3\n";
+        let mut decoder = RecordDecoder::new(Reader::new(), 2);
+        decoder.set_allow_truncated_rows(true);
+        let (read, bytes) = decoder.decode(csv.as_bytes(), 5).unwrap();
+        assert_eq!(read, 5);
+        assert_eq!(bytes, csv.len());
+
+        // The short row `v` must be padded to two fields, the second one empty
+        let records = decoder.flush().unwrap();
+        let rows: Vec<_> = records.iter().map(|r| [r.get(0), r.get(1)]).collect();
+        assert_eq!(rows, [["a", "b"], ["v", ""], ["", "1"], ["", "2"], ["", "3"]]);
+    }
 }
diff --git a/arrow-csv/test/data/truncated_rows.csv b/arrow-csv/test/data/truncated_rows.csv
new file mode 100644
index 000000000000..d3f80a07b468
--- /dev/null
+++ b/arrow-csv/test/data/truncated_rows.csv
@@ -0,0 +1,7 @@
+Name,Age,Occupation,DOB
+A1,34,Engineer,1985-07-16
+B2,29,Doctor
+,,
+C3,45,Artist,1964-11-30
+D4
+E5,31,,