Skip to content

Commit

Permalink
csv: Add option to specify custom null values
Browse files Browse the repository at this point in the history
Custom strings can now be specified as `NULL` values for CSVs. This allows
reading CSV files that use placeholders for NULL values instead of
empty strings.

Fixes #4794

Signed-off-by: Vaibhav <[email protected]>
  • Loading branch information
vrongmeal committed Sep 7, 2023
1 parent 15dde87 commit ebbf1fb
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 27 deletions.
138 changes: 111 additions & 27 deletions arrow-csv/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ use chrono::{TimeZone, Utc};
use csv::StringRecord;
use lazy_static::lazy_static;
use regex::RegexSet;
use std::collections::HashSet;
use std::fmt;
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
Expand Down Expand Up @@ -213,6 +214,7 @@ pub struct Format {
escape: Option<u8>,
quote: Option<u8>,
terminator: Option<u8>,
nulls: HashSet<String>,
}

impl Format {
Expand Down Expand Up @@ -241,6 +243,11 @@ impl Format {
self
}

pub fn with_nulls(mut self, nulls: HashSet<String>) -> Self {
self.nulls = nulls;
self
}

/// Infer schema of CSV records from the provided `reader`
///
/// If `max_records` is `None`, all records will be read, otherwise up to `max_records`
Expand Down Expand Up @@ -557,6 +564,9 @@ pub struct Decoder {

/// A decoder for [`StringRecords`]
record_decoder: RecordDecoder,

/// Values to consider as null.
nulls: HashSet<String>,
}

impl Decoder {
Expand Down Expand Up @@ -603,6 +613,7 @@ impl Decoder {
Some(self.schema.metadata.clone()),
self.projection.as_ref(),
self.line_number,
&self.nulls,
)?;
self.line_number += rows.len();
Ok(Some(batch))
Expand All @@ -621,26 +632,30 @@ fn parse(
metadata: Option<std::collections::HashMap<String, String>>,
projection: Option<&Vec<usize>>,
line_number: usize,
nulls: &HashSet<String>,
) -> Result<RecordBatch, ArrowError> {
let projection: Vec<usize> = match projection {
Some(v) => v.clone(),
None => fields.iter().enumerate().map(|(i, _)| i).collect(),
};

let is_null = |s: &str| -> bool { s.is_empty() || nulls.contains(s) };

let arrays: Result<Vec<ArrayRef>, _> = projection
.iter()
.map(|i| {
let i = *i;
let field = &fields[i];
match field.data_type() {
DataType::Boolean => build_boolean_array(line_number, rows, i),
DataType::Boolean => build_boolean_array(line_number, rows, i, is_null),
DataType::Decimal128(precision, scale) => {
build_decimal_array::<Decimal128Type>(
line_number,
rows,
i,
*precision,
*scale,
is_null,
)
}
DataType::Decimal256(precision, scale) => {
Expand All @@ -650,60 +665,78 @@ fn parse(
i,
*precision,
*scale,
is_null,
)
}
DataType::Int8 => build_primitive_array::<Int8Type>(line_number, rows, i),
DataType::Int8 => {
build_primitive_array::<Int8Type>(line_number, rows, i, is_null)
}
DataType::Int16 => {
build_primitive_array::<Int16Type>(line_number, rows, i)
build_primitive_array::<Int16Type>(line_number, rows, i, is_null)
}
DataType::Int32 => {
build_primitive_array::<Int32Type>(line_number, rows, i)
build_primitive_array::<Int32Type>(line_number, rows, i, is_null)
}
DataType::Int64 => {
build_primitive_array::<Int64Type>(line_number, rows, i)
build_primitive_array::<Int64Type>(line_number, rows, i, is_null)
}
DataType::UInt8 => {
build_primitive_array::<UInt8Type>(line_number, rows, i)
build_primitive_array::<UInt8Type>(line_number, rows, i, is_null)
}
DataType::UInt16 => {
build_primitive_array::<UInt16Type>(line_number, rows, i)
build_primitive_array::<UInt16Type>(line_number, rows, i, is_null)
}
DataType::UInt32 => {
build_primitive_array::<UInt32Type>(line_number, rows, i)
build_primitive_array::<UInt32Type>(line_number, rows, i, is_null)
}
DataType::UInt64 => {
build_primitive_array::<UInt64Type>(line_number, rows, i)
build_primitive_array::<UInt64Type>(line_number, rows, i, is_null)
}
DataType::Float32 => {
build_primitive_array::<Float32Type>(line_number, rows, i)
build_primitive_array::<Float32Type>(line_number, rows, i, is_null)
}
DataType::Float64 => {
build_primitive_array::<Float64Type>(line_number, rows, i)
build_primitive_array::<Float64Type>(line_number, rows, i, is_null)
}
DataType::Date32 => {
build_primitive_array::<Date32Type>(line_number, rows, i)
build_primitive_array::<Date32Type>(line_number, rows, i, is_null)
}
DataType::Date64 => {
build_primitive_array::<Date64Type>(line_number, rows, i)
}
DataType::Time32(TimeUnit::Second) => {
build_primitive_array::<Time32SecondType>(line_number, rows, i)
build_primitive_array::<Date64Type>(line_number, rows, i, is_null)
}
DataType::Time32(TimeUnit::Second) => build_primitive_array::<
Time32SecondType,
>(
line_number, rows, i, is_null
),
DataType::Time32(TimeUnit::Millisecond) => {
build_primitive_array::<Time32MillisecondType>(line_number, rows, i)
build_primitive_array::<Time32MillisecondType>(
line_number,
rows,
i,
is_null,
)
}
DataType::Time64(TimeUnit::Microsecond) => {
build_primitive_array::<Time64MicrosecondType>(line_number, rows, i)
}
DataType::Time64(TimeUnit::Nanosecond) => {
build_primitive_array::<Time64NanosecondType>(line_number, rows, i)
build_primitive_array::<Time64MicrosecondType>(
line_number,
rows,
i,
is_null,
)
}
DataType::Time64(TimeUnit::Nanosecond) => build_primitive_array::<
Time64NanosecondType,
>(
line_number, rows, i, is_null
),
DataType::Timestamp(TimeUnit::Second, tz) => {
build_timestamp_array::<TimestampSecondType>(
line_number,
rows,
i,
tz.as_deref(),
is_null,
)
}
DataType::Timestamp(TimeUnit::Millisecond, tz) => {
Expand All @@ -712,6 +745,7 @@ fn parse(
rows,
i,
tz.as_deref(),
is_null,
)
}
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
Expand All @@ -720,6 +754,7 @@ fn parse(
rows,
i,
tz.as_deref(),
is_null,
)
}
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
Expand All @@ -728,6 +763,7 @@ fn parse(
rows,
i,
tz.as_deref(),
is_null,
)
}
DataType::Utf8 => Ok(Arc::new(
Expand Down Expand Up @@ -827,11 +863,12 @@ fn build_decimal_array<T: DecimalType>(
col_idx: usize,
precision: u8,
scale: i8,
is_null: impl Fn(&str) -> bool,
) -> Result<ArrayRef, ArrowError> {
let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
for row in rows.iter() {
let s = row.get(col_idx);
if s.is_empty() {
if is_null(s) {
// append null
decimal_builder.append_null();
} else {
Expand Down Expand Up @@ -859,12 +896,13 @@ fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
line_number: usize,
rows: &StringRecords<'_>,
col_idx: usize,
is_null: impl Fn(&str) -> bool,
) -> Result<ArrayRef, ArrowError> {
rows.iter()
.enumerate()
.map(|(row_index, row)| {
let s = row.get(col_idx);
if s.is_empty() {
if is_null(s) {
return Ok(None);
}

Expand All @@ -888,14 +926,17 @@ fn build_timestamp_array<T: ArrowTimestampType>(
rows: &StringRecords<'_>,
col_idx: usize,
timezone: Option<&str>,
is_null: impl Fn(&str) -> bool,
) -> Result<ArrayRef, ArrowError> {
Ok(Arc::new(match timezone {
Some(timezone) => {
let tz: Tz = timezone.parse()?;
build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz)?
build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &tz, is_null)?
.with_timezone(timezone)
}
None => build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc)?,
None => {
build_timestamp_array_impl::<T, _>(line_number, rows, col_idx, &Utc, is_null)?
}
}))
}

Expand All @@ -904,12 +945,13 @@ fn build_timestamp_array_impl<T: ArrowTimestampType, Tz: TimeZone>(
rows: &StringRecords<'_>,
col_idx: usize,
timezone: &Tz,
is_null: impl Fn(&str) -> bool,
) -> Result<PrimitiveArray<T>, ArrowError> {
rows.iter()
.enumerate()
.map(|(row_index, row)| {
let s = row.get(col_idx);
if s.is_empty() {
if is_null(s) {
return Ok(None);
}

Expand All @@ -936,12 +978,13 @@ fn build_boolean_array(
line_number: usize,
rows: &StringRecords<'_>,
col_idx: usize,
is_null: impl Fn(&str) -> bool,
) -> Result<ArrayRef, ArrowError> {
rows.iter()
.enumerate()
.map(|(row_index, row)| {
let s = row.get(col_idx);
if s.is_empty() {
if is_null(s) {
return Ok(None);
}
let parsed = parse_bool(s);
Expand Down Expand Up @@ -975,6 +1018,8 @@ pub struct ReaderBuilder {
bounds: Bounds,
/// Optional projection for which columns to load (zero-based column indices)
projection: Option<Vec<usize>>,
/// Strings to consider as `NULL` when parsing.
nulls: HashSet<String>,
}

impl ReaderBuilder {
Expand Down Expand Up @@ -1006,6 +1051,7 @@ impl ReaderBuilder {
batch_size: 1024,
bounds: None,
projection: None,
nulls: HashSet::new(),
}
}

Expand Down Expand Up @@ -1042,6 +1088,11 @@ impl ReaderBuilder {
self
}

pub fn with_nulls(mut self, nulls: HashSet<String>) -> Self {
self.nulls = nulls;
self
}

/// Set the batch size (number of records to load at one time)
pub fn with_batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = batch_size;
Expand Down Expand Up @@ -1100,6 +1151,7 @@ impl ReaderBuilder {
end,
projection: self.projection,
batch_size: self.batch_size,
nulls: self.nulls,
}
}
}
Expand Down Expand Up @@ -1426,6 +1478,38 @@ mod tests {
assert!(!batch.column(1).is_null(4));
}

#[test]
fn test_custom_nulls() {
    // Four nullable columns: unsigned integer, float, string and boolean.
    let fields = vec![
        Field::new("c_int", DataType::UInt64, true),
        Field::new("c_float", DataType::Float32, true),
        Field::new("c_string", DataType::Utf8, true),
        Field::new("c_bool", DataType::Boolean, true),
    ];
    let schema = Arc::new(Schema::new(fields));

    let input = File::open("test/data/custom_null_test.csv").unwrap();

    // "nil" is the only custom NULL marker configured for this reader.
    let nulls: HashSet<String> = HashSet::from(["nil".to_string()]);

    let mut reader = ReaderBuilder::new(schema)
        .has_header(true)
        .with_nulls(nulls)
        .build(input)
        .unwrap();

    let batch = reader.next().unwrap().unwrap();

    // Cells holding "nil" in the integer, float and boolean columns are NULL.
    for (column, row) in [(0, 1), (1, 2), (3, 4)] {
        assert!(batch.column(column).is_null(row));
    }

    // A plain empty cell is still NULL, as it is without custom markers.
    assert!(batch.column(1).is_null(0));

    // The string column stays non-NULL for both "nil" (row 3) and "" (row 4).
    for row in [3, 4] {
        assert!(!batch.column(2).is_null(row));
    }
}

#[test]
fn test_nulls_with_inference() {
let mut file = File::open("test/data/various_types.csv").unwrap();
Expand Down
6 changes: 6 additions & 0 deletions arrow-csv/test/data/custom_null_test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
c_int,c_float,c_string,c_bool
1,,"1.11",True
nil,2.2,"2.22",TRUE
3,nil,"3.33",true
4,4.4,nil,False
5,6.6,"",nil

0 comments on commit ebbf1fb

Please sign in to comment.