diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs
index 328c2cd41f3b..19a7aef74529 100644
--- a/arrow-csv/src/reader/mod.rs
+++ b/arrow-csv/src/reader/mod.rs
@@ -134,6 +134,7 @@ use chrono::{TimeZone, Utc};
 use csv::StringRecord;
 use lazy_static::lazy_static;
 use regex::RegexSet;
+use std::collections::HashSet;
 use std::fmt;
 use std::fs::File;
 use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
@@ -213,6 +214,7 @@ pub struct Format {
     escape: Option<u8>,
     quote: Option<u8>,
     terminator: Option<u8>,
+    nulls: HashSet<String>,
 }
 
 impl Format {
@@ -241,6 +243,11 @@ impl Format {
         self
     }
 
+    pub fn with_nulls(mut self, nulls: HashSet<String>) -> Self {
+        self.nulls = nulls;
+        self
+    }
+
     /// Infer schema of CSV records from the provided `reader`
     ///
     /// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
@@ -557,6 +564,9 @@ pub struct Decoder {
 
     /// A decoder for [`StringRecords`]
     record_decoder: RecordDecoder,
+
+    /// Values to consider as null.
+    nulls: HashSet<String>,
 }
 
 impl Decoder {
@@ -603,6 +613,7 @@ impl Decoder {
             Some(self.schema.metadata.clone()),
             self.projection.as_ref(),
             self.line_number,
+            &self.nulls,
         )?;
         self.line_number += rows.len();
         Ok(Some(batch))
@@ -621,19 +632,22 @@ fn parse(
     metadata: Option<std::collections::HashMap<String, String>>,
     projection: Option<&Vec<usize>>,
     line_number: usize,
+    nulls: &HashSet<String>,
 ) -> Result<RecordBatch, ArrowError> {
     let projection: Vec<usize> = match projection {
         Some(v) => v.clone(),
         None => fields.iter().enumerate().map(|(i, _)| i).collect(),
     };
 
+    let is_null = |s: &str| -> bool { s.is_empty() || nulls.contains(s) };
+
     let arrays: Result<Vec<ArrayRef>, _> = projection
         .iter()
         .map(|i| {
             let i = *i;
             let field = &fields[i];
             match field.data_type() {
-                DataType::Boolean => build_boolean_array(line_number, rows, i),
+                DataType::Boolean => build_boolean_array(line_number, rows, i, is_null),
                 DataType::Decimal128(precision, scale) => {
                     build_decimal_array::<Decimal128Type>(
                         line_number,
                         rows,
                         i,
                         *precision,
                         *scale,
+                        is_null,
                     )
                 }
                 DataType::Decimal256(precision, scale) => {
                     build_decimal_array::<Decimal256Type>(
@@ -650,60 +665,78 @@ fn parse(
                         i,
                         *precision,
                         *scale,
+                        is_null,
                     )
                 }
-                DataType::Int8 => build_primitive_array::<Int8Type>(line_number, rows, i),
+                DataType::Int8 => {
+                    build_primitive_array::<Int8Type>(line_number, rows, i, is_null)
+                }
                 DataType::Int16 => {
-                    build_primitive_array::<Int16Type>(line_number, rows, i)
+                    build_primitive_array::<Int16Type>(line_number, rows, i, is_null)
                 }
                 DataType::Int32 => {
-                    build_primitive_array::<Int32Type>(line_number, rows, i)
+                    build_primitive_array::<Int32Type>(line_number, rows, i, is_null)
                 }
                 DataType::Int64 => {
-                    build_primitive_array::<Int64Type>(line_number, rows, i)
+                    build_primitive_array::<Int64Type>(line_number, rows, i, is_null)
                 }
                 DataType::UInt8 => {
-                    build_primitive_array::<UInt8Type>(line_number, rows, i)
+                    build_primitive_array::<UInt8Type>(line_number, rows, i, is_null)
                 }
                 DataType::UInt16 => {
-                    build_primitive_array::<UInt16Type>(line_number, rows, i)
+                    build_primitive_array::<UInt16Type>(line_number, rows, i, is_null)
                 }
                 DataType::UInt32 => {
-                    build_primitive_array::<UInt32Type>(line_number, rows, i)
+                    build_primitive_array::<UInt32Type>(line_number, rows, i, is_null)
                 }
                 DataType::UInt64 => {
-                    build_primitive_array::<UInt64Type>(line_number, rows, i)
+                    build_primitive_array::<UInt64Type>(line_number, rows, i, is_null)
                 }
                 DataType::Float32 => {
-                    build_primitive_array::<Float32Type>(line_number, rows, i)
+                    build_primitive_array::<Float32Type>(line_number, rows, i, is_null)
                 }
                 DataType::Float64 => {
-                    build_primitive_array::<Float64Type>(line_number, rows, i)
+                    build_primitive_array::<Float64Type>(line_number, rows, i, is_null)
                 }
                 DataType::Date32 => {
-                    build_primitive_array::<Date32Type>(line_number, rows, i)
+                    build_primitive_array::<Date32Type>(line_number, rows, i, is_null)
                 }
                 DataType::Date64 => {
-                    build_primitive_array::<Date64Type>(line_number, rows, i)
-                }
-                DataType::Time32(TimeUnit::Second) => {
-                    build_primitive_array::<Time32SecondType>(line_number, rows, i)
+                    build_primitive_array::<Date64Type>(line_number, rows, i, is_null)
                 }
+                DataType::Time32(TimeUnit::Second) => build_primitive_array::<
+                    Time32SecondType,
+                >(
+                    line_number, rows, i, is_null
+                ),
                 DataType::Time32(TimeUnit::Millisecond) => {
-                    build_primitive_array::<Time32MillisecondType>(line_number, rows, i)
+                    build_primitive_array::<Time32MillisecondType>(
+                        line_number,
+                        rows,
+                        i,
+                        is_null,
+                    )
                 }
                 DataType::Time64(TimeUnit::Microsecond) => {
-                    build_primitive_array::<Time64MicrosecondType>(line_number, rows, i)
-                }
-                DataType::Time64(TimeUnit::Nanosecond) => {
-                    build_primitive_array::<Time64NanosecondType>(line_number, rows, i)
+                    build_primitive_array::<Time64MicrosecondType>(
+                        line_number,
+                        rows,
+                        i,
+                        is_null,
+                    )
                 }
+                DataType::Time64(TimeUnit::Nanosecond) => build_primitive_array::<
+                    Time64NanosecondType,
+                >(
+                    line_number, rows, i, is_null
+                ),
                 DataType::Timestamp(TimeUnit::Second, tz) => {
                     build_timestamp_array::<TimestampSecondType>(
                         line_number,
                         rows,
                         i,
                         tz.as_deref(),
+                        is_null,
                     )
                 }
                 DataType::Timestamp(TimeUnit::Millisecond, tz) => {
@@ -712,6 +745,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        is_null,
                     )
                 }
                 DataType::Timestamp(TimeUnit::Microsecond, tz) => {
@@ -720,6 +754,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        is_null,
                     )
                 }
                 DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
@@ -728,6 +763,7 @@ fn parse(
                         rows,
                         i,
                         tz.as_deref(),
+                        is_null,
                     )
                 }
                 DataType::Utf8 => Ok(Arc::new(
@@ -827,11 +863,12 @@ fn build_decimal_array<T: DecimalType>(
     col_idx: usize,
     precision: u8,
     scale: i8,
+    is_null: impl Fn(&str) -> bool,
 ) -> Result<ArrayRef, ArrowError> {
     let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
     for row in rows.iter() {
         let s = row.get(col_idx);
-        if s.is_empty() {
+        if is_null(s) {
             // append null
             decimal_builder.append_null();
         } else {
@@ -859,12 +896,13 @@ fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
     line_number: usize,
     rows: &StringRecords<'_>,
     col_idx: usize,
+    is_null: impl Fn(&str) -> bool,
 ) -> Result<ArrayRef, ArrowError> {
     rows.iter()
         .enumerate()
         .map(|(row_index, row)| {
             let s = row.get(col_idx);
-            if s.is_empty() {
+            if is_null(s) {
                 return Ok(None);
             }
 
@@ -888,14 +926,17 @@ fn build_timestamp_array<T: ArrowTimestampType>(
     rows: &StringRecords<'_>,
     col_idx: usize,
     timezone: Option<&str>,
+    is_null: impl Fn(&str) -> bool,
 ) -> Result<ArrayRef, ArrowError> {
     Ok(Arc::new(match timezone {
         Some(timezone) => {
             let tz: Tz = timezone.parse()?;
-            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz)?
+            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, is_null)?
                 .with_timezone(timezone)
         }
-        None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc)?,
+        None => {
+            build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, is_null)?
+        }
     }))
 }
 
@@ -904,12 +945,13 @@ fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
     rows: &StringRecords<'_>,
     col_idx: usize,
     timezone: &Tz,
+    is_null: impl Fn(&str) -> bool,
 ) -> Result<PrimitiveArray<T>, ArrowError> {
     rows.iter()
         .enumerate()
         .map(|(row_index, row)| {
             let s = row.get(col_idx);
-            if s.is_empty() {
+            if is_null(s) {
                 return Ok(None);
             }
 
@@ -936,12 +978,13 @@ fn build_boolean_array(
     line_number: usize,
     rows: &StringRecords<'_>,
     col_idx: usize,
+    is_null: impl Fn(&str) -> bool,
 ) -> Result<ArrayRef, ArrowError> {
     rows.iter()
         .enumerate()
         .map(|(row_index, row)| {
             let s = row.get(col_idx);
-            if s.is_empty() {
+            if is_null(s) {
                 return Ok(None);
             }
             let parsed = parse_bool(s);
@@ -975,6 +1018,8 @@ pub struct ReaderBuilder {
     bounds: Bounds,
     /// Optional projection for which columns to load (zero-based column indices)
     projection: Option<Vec<usize>>,
+    /// Strings to consider as `NULL` when parsing.
+    nulls: HashSet<String>,
 }
 
 impl ReaderBuilder {
@@ -1006,6 +1051,7 @@ impl ReaderBuilder {
             batch_size: 1024,
             bounds: None,
             projection: None,
+            nulls: HashSet::new(),
         }
     }
 
@@ -1042,6 +1088,11 @@ impl ReaderBuilder {
         self
     }
 
+    pub fn with_nulls(mut self, nulls: HashSet<String>) -> Self {
+        self.nulls = nulls;
+        self
+    }
+
     /// Set the batch size (number of records to load at one time)
     pub fn with_batch_size(mut self, batch_size: usize) -> Self {
         self.batch_size = batch_size;
@@ -1100,6 +1151,7 @@ impl ReaderBuilder {
             end,
             projection: self.projection,
             batch_size: self.batch_size,
+            nulls: self.nulls,
         }
     }
 }
@@ -1426,6 +1478,38 @@ mod tests {
         assert!(!batch.column(1).is_null(4));
     }
 
+    #[test]
+    fn test_custom_nulls() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("c_int", DataType::UInt64, true),
+            Field::new("c_float", DataType::Float32, true),
+            Field::new("c_string", DataType::Utf8, true),
+            Field::new("c_bool", DataType::Boolean, true),
+        ]));
+
+        let file = File::open("test/data/custom_null_test.csv").unwrap();
+
+        let nulls: HashSet<String> = ["nil"].into_iter().map(|s| s.to_string()).collect();
+
+        let mut csv = ReaderBuilder::new(schema)
+            .has_header(true)
+            .with_nulls(nulls)
+            .build(file)
+            .unwrap();
+
+        let batch = csv.next().unwrap().unwrap();
+
+        // "nil"s should be NULL
+        assert!(batch.column(0).is_null(1));
+        assert!(batch.column(1).is_null(2));
+        assert!(batch.column(3).is_null(4));
+        // Standard empty (NULL) "".
+        assert!(batch.column(1).is_null(0));
+        // String won't be empty
+        assert!(!batch.column(2).is_null(3));
+        assert!(!batch.column(2).is_null(4));
+    }
+
     #[test]
     fn test_nulls_with_inference() {
         let mut file = File::open("test/data/various_types.csv").unwrap();
diff --git a/arrow-csv/test/data/custom_null_test.csv b/arrow-csv/test/data/custom_null_test.csv
new file mode 100644
index 000000000000..30d7b7f2a1bf
--- /dev/null
+++ b/arrow-csv/test/data/custom_null_test.csv
@@ -0,0 +1,6 @@
+c_int,c_float,c_string,c_bool
+1,,"1.11",True
+nil,2.2,"2.22",TRUE
+3,nil,"3.33",true
+4,4.4,nil,False
+5,6.6,"",nil
\ No newline at end of file