From b155461f770eb2ab8cc5d3296f6123582cf5073d Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Mon, 5 Dec 2022 10:34:02 +0000 Subject: [PATCH] Loosen nullability restrictions added in #3205 (#3226) (#3244) * Loosen nullability restrictions added in #3205 (#3226) * Fix tests * More test fixes * Yet more incorrect tests * Review feedback --- arrow-array/src/array/binary_array.rs | 2 +- arrow-array/src/array/mod.rs | 4 +- arrow-array/src/array/string_array.rs | 5 +- arrow-array/src/array/struct_array.rs | 30 ++++---- arrow-cast/src/cast.rs | 3 +- arrow-data/src/data.rs | 103 +++++++++++++++++++++++++- arrow-select/src/take.rs | 4 +- arrow/src/compute/kernels/limit.rs | 4 +- arrow/src/row/mod.rs | 35 +-------- 9 files changed, 133 insertions(+), 57 deletions(-) diff --git a/arrow-array/src/array/binary_array.rs b/arrow-array/src/array/binary_array.rs index 0b526ecb3dee..3a30d748ee3a 100644 --- a/arrow-array/src/array/binary_array.rs +++ b/arrow-array/src/array/binary_array.rs @@ -531,7 +531,7 @@ mod tests { let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), + Field::new("item", DataType::UInt8, true), )); // [None, Some(b"Parquet")] diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index 0f9a2ce59291..1e17e35d0f6d 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -915,8 +915,10 @@ mod tests { #[test] fn test_null_struct() { + // It is possible to create a null struct containing a non-nullable child + // see https://github.com/apache/arrow-rs/pull/3244 for details let struct_type = - DataType::Struct(vec![Field::new("data", DataType::Int64, true)]); + DataType::Struct(vec![Field::new("data", DataType::Int64, false)]); let array = new_null_array(&struct_type, 9); let a = array.as_any().downcast_ref::().unwrap(); diff --git a/arrow-array/src/array/string_array.rs b/arrow-array/src/array/string_array.rs index fb3bb23179b5..c8db589e3c28 100644 --- a/arrow-array/src/array/string_array.rs +++ b/arrow-array/src/array/string_array.rs @@ -608,8 +608,11 @@ mod tests { .unwrap(); let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap()); + + // It is possible to create a null struct containing a non-nullable child + // see https://github.com/apache/arrow-rs/pull/3244 for details let data_type = GenericListArray::::DATA_TYPE_CONSTRUCTOR(Box::new( - Field::new("item", DataType::UInt8, false), + Field::new("item", DataType::UInt8, true), )); // [None, Some(b"Parquet")] diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index 7d88cc5c6deb..bf6489c1380c 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -227,13 +227,6 @@ impl From> for StructArray { field_value.data().data_type(), "the field data types must match the array data in a StructArray" ); - // Check nullability of child arrays - if !field_type.is_nullable() { - assert!( - field_value.null_count() == 0, - "non-nullable field cannot have null values" - ); - } }, ); @@ -241,6 +234,10 @@ impl From> for StructArray { .child_data(field_values.into_iter().map(|a| a.into_data()).collect()) .len(length); let array_data = unsafe { array_data.build_unchecked() }; + + // We must validate nullability + array_data.validate_nulls().unwrap(); + Self::from(array_data) } } @@ -283,13 +280,6 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray { field_value.data().data_type(), "the field data types must match the array data in a StructArray" ); - // Check nullability of child arrays - if !field_type.is_nullable() { - assert!( - field_value.null_count() == 0, - "non-nullable field cannot have null values" - ); - } }, ); @@ -298,6 +288,10 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray { .child_data(field_values.into_iter().map(|a| a.into_data()).collect()) .len(length); let array_data = unsafe { array_data.build_unchecked() }; + + // We must validate nullability + array_data.validate_nulls().unwrap(); + Self::from(array_data) } } @@ -470,8 +464,8 @@ mod tests { .unwrap(); let field_types = vec![ - Field::new("a", DataType::Boolean, false), - Field::new("b", DataType::Int32, false), + Field::new("a", DataType::Boolean, true), + Field::new("b", DataType::Int32, true), ]; let struct_array_data = ArrayData::builder(DataType::Struct(field_types)) .len(5) @@ -568,7 +562,9 @@ mod tests { } #[test] - #[should_panic(expected = "non-nullable field cannot have null values")] + #[should_panic( + expected = "non-nullable child of type Int32 contains nulls not present in parent Struct" + )] fn test_struct_array_from_mismatched_nullability() { drop(StructArray::from(vec![( Field::new("c", DataType::Int32, false), diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 272a422eb114..7bb3aeb9603f 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -6594,7 +6594,8 @@ mod tests { cast_from_null_to_other(&data_type); // Cast null from and to struct - let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, true)]); + let data_type = + DataType::Struct(vec![Field::new("data", DataType::Int64, false)]); cast_from_null_to_other(&data_type); } diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index b230dfdb7564..b38321aacf4c 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -19,6 +19,7 @@ //! common attributes and operations for Arrow array. use crate::{bit_iterator::BitSliceIterator, bitmap::Bitmap}; +use arrow_buffer::bit_chunk_iterator::BitChunks; use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer}; use arrow_schema::{ArrowError, DataType, IntervalUnit, UnionMode}; use half::f16; @@ -975,6 +976,7 @@ impl ArrayData { /// see [`Self::validate_full`] pub fn validate_data(&self) -> Result<(), ArrowError> { self.validate()?; + self.validate_nulls()?; self.validate_values()?; Ok(()) @@ -1001,7 +1003,13 @@ impl ArrayData { Ok(()) } - /// Validates the the null count is correct + /// Validates the values stored within this [`ArrayData`] are valid + /// without recursing into child [`ArrayData`] + /// + /// Does not (yet) check + /// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85) + /// Validates the the null count is correct and that any + /// nullability requirements of its children are correct pub fn validate_nulls(&self) -> Result<(), ArrowError> { let nulls = self.null_buffer(); @@ -1012,9 +1020,102 @@ impl ArrayData { self.null_count, actual_null_count ))); } + + // In general non-nullable children should not contain nulls, however, for certain + // types, such as StructArray and FixedSizeList, nulls in the parent take up + // space in the child. As such we permit nulls in the children in the corresponding + // positions for such types + match &self.data_type { + DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => { + if !f.is_nullable() { + self.validate_non_nullable(None, 0, &self.child_data[0])? + } + } + DataType::FixedSizeList(field, len) => { + let child = &self.child_data[0]; + if !field.is_nullable() { + match nulls { + Some(nulls) => { + let element_len = *len as usize; + let mut buffer = + MutableBuffer::new_null(element_len * self.len); + + // Expand each bit within `null_mask` into `element_len` + // bits, constructing the implicit mask of the child elements + for i in 0..self.len { + if !bit_util::get_bit(nulls.as_ref(), self.offset + i) { + continue; + } + for j in 0..element_len { + bit_util::set_bit( + buffer.as_mut(), + i * element_len + j, + ) + } + } + let mask = buffer.into(); + self.validate_non_nullable(Some(&mask), 0, child)?; + } + None => self.validate_non_nullable(None, 0, child)?, + } + } + } + DataType::Struct(fields) => { + for (field, child) in fields.iter().zip(&self.child_data) { + if !field.is_nullable() { + self.validate_non_nullable(nulls, self.offset, child)? + } + } + } + _ => {} + } + Ok(()) } + /// Verifies that `child` contains no nulls not present in `mask` + fn validate_non_nullable( + &self, + mask: Option<&Buffer>, + offset: usize, + data: &ArrayData, + ) -> Result<(), ArrowError> { + let mask = match mask { + Some(mask) => mask.as_ref(), + None => return match data.null_count { + 0 => Ok(()), + _ => Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent {}", + data.data_type(), + self.data_type + ))), + }, + }; + + match data.null_buffer() { + Some(nulls) => { + let mask = BitChunks::new(mask, offset, data.len); + let nulls = BitChunks::new(nulls.as_ref(), data.offset, data.len); + mask + .iter() + .zip(nulls.iter()) + .chain(std::iter::once(( + mask.remainder_bits(), + nulls.remainder_bits(), + ))).try_for_each(|(m, c)| { + if (m & !c) != 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "non-nullable child of type {} contains nulls not present in parent", + data.data_type() + ))) + } + Ok(()) + }) + } + None => Ok(()), + } + } + /// Validates the values stored within this [`ArrayData`] are valid /// without recursing into child [`ArrayData`] /// diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index 857b6e3231ba..0b1d44319493 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -1603,7 +1603,7 @@ mod tests { let list_data_type = DataType::$list_data_type(Box::new(Field::new( "item", DataType::Int32, - false, + true, ))); let list_data = ArrayData::builder(list_data_type.clone()) .len(4) @@ -1676,7 +1676,7 @@ mod tests { let list_data_type = DataType::$list_data_type(Box::new(Field::new( "item", DataType::Int32, - false, + true, ))); let list_data = ArrayData::builder(list_data_type.clone()) .len(4) diff --git a/arrow/src/compute/kernels/limit.rs b/arrow/src/compute/kernels/limit.rs index 1f6c6aec5e1f..0d92e98cf718 100644 --- a/arrow/src/compute/kernels/limit.rs +++ b/arrow/src/compute/kernels/limit.rs @@ -158,8 +158,8 @@ mod tests { .unwrap(); let field_types = vec![ - Field::new("a", DataType::Boolean, false), - Field::new("b", DataType::Int32, false), + Field::new("a", DataType::Boolean, true), + Field::new("b", DataType::Int32, true), ]; let struct_array_data = ArrayData::builder(DataType::Struct(field_types)) .len(5) diff --git a/arrow/src/row/mod.rs b/arrow/src/row/mod.rs index abb8039cc398..ea3def6ac831 100644 --- a/arrow/src/row/mod.rs +++ b/arrow/src/row/mod.rs @@ -1225,36 +1225,11 @@ unsafe fn decode_column( } } Codec::Struct(converter, _) => { - let child_fields = match &field.data_type { - DataType::Struct(f) => f, - _ => unreachable!(), - }; - let (null_count, nulls) = fixed::decode_nulls(rows); rows.iter_mut().for_each(|row| *row = &row[1..]); let children = converter.convert_raw(rows, validate_utf8)?; - let child_data = child_fields - .iter() - .zip(&children) - .map(|(f, c)| { - let data = c.data().clone(); - match f.is_nullable() { - true => data, - false => { - assert_eq!(data.null_count(), null_count); - // Need to strip out null buffer if any as this is created - // as an artifact of the row encoding process that encodes - // nulls from the parent struct array in the children - data.into_builder() - .null_count(0) - .null_bit_buffer(None) - .build_unchecked() - } - } - }) - .collect(); - + let child_data = children.iter().map(|c| c.data().clone()).collect(); let builder = ArrayDataBuilder::new(field.data_type.clone()) .len(rows.len()) .null_count(null_count) @@ -1712,11 +1687,8 @@ mod tests { let back = converter.convert_rows(&r2).unwrap(); assert_eq!(back.len(), 1); assert_eq!(&back[0], &s2); - let back_s = as_struct_array(&back[0]); - for c in back_s.columns() { - // Children should not contain nulls - assert_eq!(c.null_count(), 0); - } + + back[0].data().validate_full().unwrap(); } #[test] @@ -2198,6 +2170,7 @@ mod tests { let back = converter.convert_rows(&rows).unwrap(); for (actual, expected) in back.iter().zip(&arrays) { + actual.data().validate_full().unwrap(); assert_eq!(actual, expected) } }