Skip to content

Commit

Permalink
Loosen nullability restrictions added in #3205 (#3226) (#3244)
Browse files Browse the repository at this point in the history
* Loosen nullability restrictions added in #3205 (#3226)

* Fix tests

* More test fixes

* Yet more incorrect tests

* Review feedback
  • Loading branch information
tustvold authored Dec 5, 2022
1 parent 06e1111 commit b155461
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 57 deletions.
2 changes: 1 addition & 1 deletion arrow-array/src/array/binary_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ mod tests {

let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap());
let data_type = GenericListArray::<O>::DATA_TYPE_CONSTRUCTOR(Box::new(
Field::new("item", DataType::UInt8, false),
Field::new("item", DataType::UInt8, true),
));

// [None, Some(b"Parquet")]
Expand Down
4 changes: 3 additions & 1 deletion arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,8 +915,10 @@ mod tests {

#[test]
fn test_null_struct() {
// It is possible to create a null struct containing a non-nullable child
// see https://github.com/apache/arrow-rs/pull/3244 for details
let struct_type =
DataType::Struct(vec![Field::new("data", DataType::Int64, true)]);
DataType::Struct(vec![Field::new("data", DataType::Int64, false)]);
let array = new_null_array(&struct_type, 9);

let a = array.as_any().downcast_ref::<StructArray>().unwrap();
Expand Down
5 changes: 4 additions & 1 deletion arrow-array/src/array/string_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,11 @@ mod tests {
.unwrap();

let offsets = [0, 5, 10].map(|n| O::from_usize(n).unwrap());

// It is possible to create a null struct containing a non-nullable child
// see https://github.com/apache/arrow-rs/pull/3244 for details
let data_type = GenericListArray::<O>::DATA_TYPE_CONSTRUCTOR(Box::new(
Field::new("item", DataType::UInt8, false),
Field::new("item", DataType::UInt8, true),
));

// [None, Some(b"Parquet")]
Expand Down
30 changes: 13 additions & 17 deletions arrow-array/src/array/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,20 +227,17 @@ impl From<Vec<(Field, ArrayRef)>> for StructArray {
field_value.data().data_type(),
"the field data types must match the array data in a StructArray"
);
// Check nullability of child arrays
if !field_type.is_nullable() {
assert!(
field_value.null_count() == 0,
"non-nullable field cannot have null values"
);
}
},
);

let array_data = ArrayData::builder(DataType::Struct(field_types))
.child_data(field_values.into_iter().map(|a| a.into_data()).collect())
.len(length);
let array_data = unsafe { array_data.build_unchecked() };

// We must validate nullability
array_data.validate_nulls().unwrap();

Self::from(array_data)
}
}
Expand Down Expand Up @@ -283,13 +280,6 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray {
field_value.data().data_type(),
"the field data types must match the array data in a StructArray"
);
// Check nullability of child arrays
if !field_type.is_nullable() {
assert!(
field_value.null_count() == 0,
"non-nullable field cannot have null values"
);
}
},
);

Expand All @@ -298,6 +288,10 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray {
.child_data(field_values.into_iter().map(|a| a.into_data()).collect())
.len(length);
let array_data = unsafe { array_data.build_unchecked() };

// We must validate nullability
array_data.validate_nulls().unwrap();

Self::from(array_data)
}
}
Expand Down Expand Up @@ -470,8 +464,8 @@ mod tests {
.unwrap();

let field_types = vec![
Field::new("a", DataType::Boolean, false),
Field::new("b", DataType::Int32, false),
Field::new("a", DataType::Boolean, true),
Field::new("b", DataType::Int32, true),
];
let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
.len(5)
Expand Down Expand Up @@ -568,7 +562,9 @@ mod tests {
}

#[test]
#[should_panic(expected = "non-nullable field cannot have null values")]
#[should_panic(
expected = "non-nullable child of type Int32 contains nulls not present in parent Struct"
)]
fn test_struct_array_from_mismatched_nullability() {
drop(StructArray::from(vec![(
Field::new("c", DataType::Int32, false),
Expand Down
3 changes: 2 additions & 1 deletion arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6594,7 +6594,8 @@ mod tests {
cast_from_null_to_other(&data_type);

// Cast null from and to struct
let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, true)]);
let data_type =
DataType::Struct(vec![Field::new("data", DataType::Int64, false)]);
cast_from_null_to_other(&data_type);
}

Expand Down
103 changes: 102 additions & 1 deletion arrow-data/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! common attributes and operations for Arrow array.
use crate::{bit_iterator::BitSliceIterator, bitmap::Bitmap};
use arrow_buffer::bit_chunk_iterator::BitChunks;
use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer};
use arrow_schema::{ArrowError, DataType, IntervalUnit, UnionMode};
use half::f16;
Expand Down Expand Up @@ -975,6 +976,7 @@ impl ArrayData {
/// see [`Self::validate_full`]
pub fn validate_data(&self) -> Result<(), ArrowError> {
self.validate()?;

self.validate_nulls()?;
self.validate_values()?;
Ok(())
Expand All @@ -1001,7 +1003,13 @@ impl ArrayData {
Ok(())
}

/// Validates the the null count is correct
/// Validates the values stored within this [`ArrayData`] are valid
/// without recursing into child [`ArrayData`]
///
/// Does not (yet) check
/// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85)
/// Validates the the null count is correct and that any
/// nullability requirements of its children are correct
pub fn validate_nulls(&self) -> Result<(), ArrowError> {
let nulls = self.null_buffer();

Expand All @@ -1012,9 +1020,102 @@ impl ArrayData {
self.null_count, actual_null_count
)));
}

// In general non-nullable children should not contain nulls, however, for certain
// types, such as StructArray and FixedSizeList, nulls in the parent take up
// space in the child. As such we permit nulls in the children in the corresponding
// positions for such types
match &self.data_type {
DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => {
if !f.is_nullable() {
self.validate_non_nullable(None, 0, &self.child_data[0])?
}
}
DataType::FixedSizeList(field, len) => {
let child = &self.child_data[0];
if !field.is_nullable() {
match nulls {
Some(nulls) => {
let element_len = *len as usize;
let mut buffer =
MutableBuffer::new_null(element_len * self.len);

// Expand each bit within `null_mask` into `element_len`
// bits, constructing the implicit mask of the child elements
for i in 0..self.len {
if !bit_util::get_bit(nulls.as_ref(), self.offset + i) {
continue;
}
for j in 0..element_len {
bit_util::set_bit(
buffer.as_mut(),
i * element_len + j,
)
}
}
let mask = buffer.into();
self.validate_non_nullable(Some(&mask), 0, child)?;
}
None => self.validate_non_nullable(None, 0, child)?,
}
}
}
DataType::Struct(fields) => {
for (field, child) in fields.iter().zip(&self.child_data) {
if !field.is_nullable() {
self.validate_non_nullable(nulls, self.offset, child)?
}
}
}
_ => {}
}

Ok(())
}

/// Verifies that `child` contains no nulls not present in `mask`
fn validate_non_nullable(
&self,
mask: Option<&Buffer>,
offset: usize,
data: &ArrayData,
) -> Result<(), ArrowError> {
let mask = match mask {
Some(mask) => mask.as_ref(),
None => return match data.null_count {
0 => Ok(()),
_ => Err(ArrowError::InvalidArgumentError(format!(
"non-nullable child of type {} contains nulls not present in parent {}",
data.data_type(),
self.data_type
))),
},
};

match data.null_buffer() {
Some(nulls) => {
let mask = BitChunks::new(mask, offset, data.len);
let nulls = BitChunks::new(nulls.as_ref(), data.offset, data.len);
mask
.iter()
.zip(nulls.iter())
.chain(std::iter::once((
mask.remainder_bits(),
nulls.remainder_bits(),
))).try_for_each(|(m, c)| {
if (m & !c) != 0 {
return Err(ArrowError::InvalidArgumentError(format!(
"non-nullable child of type {} contains nulls not present in parent",
data.data_type()
)))
}
Ok(())
})
}
None => Ok(()),
}
}

/// Validates the values stored within this [`ArrayData`] are valid
/// without recursing into child [`ArrayData`]
///
Expand Down
4 changes: 2 additions & 2 deletions arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1603,7 +1603,7 @@ mod tests {
let list_data_type = DataType::$list_data_type(Box::new(Field::new(
"item",
DataType::Int32,
false,
true,
)));
let list_data = ArrayData::builder(list_data_type.clone())
.len(4)
Expand Down Expand Up @@ -1676,7 +1676,7 @@ mod tests {
let list_data_type = DataType::$list_data_type(Box::new(Field::new(
"item",
DataType::Int32,
false,
true,
)));
let list_data = ArrayData::builder(list_data_type.clone())
.len(4)
Expand Down
4 changes: 2 additions & 2 deletions arrow/src/compute/kernels/limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ mod tests {
.unwrap();

let field_types = vec![
Field::new("a", DataType::Boolean, false),
Field::new("b", DataType::Int32, false),
Field::new("a", DataType::Boolean, true),
Field::new("b", DataType::Int32, true),
];
let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
.len(5)
Expand Down
35 changes: 4 additions & 31 deletions arrow/src/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1225,36 +1225,11 @@ unsafe fn decode_column(
}
}
Codec::Struct(converter, _) => {
let child_fields = match &field.data_type {
DataType::Struct(f) => f,
_ => unreachable!(),
};

let (null_count, nulls) = fixed::decode_nulls(rows);
rows.iter_mut().for_each(|row| *row = &row[1..]);
let children = converter.convert_raw(rows, validate_utf8)?;

let child_data = child_fields
.iter()
.zip(&children)
.map(|(f, c)| {
let data = c.data().clone();
match f.is_nullable() {
true => data,
false => {
assert_eq!(data.null_count(), null_count);
// Need to strip out null buffer if any as this is created
// as an artifact of the row encoding process that encodes
// nulls from the parent struct array in the children
data.into_builder()
.null_count(0)
.null_bit_buffer(None)
.build_unchecked()
}
}
})
.collect();

let child_data = children.iter().map(|c| c.data().clone()).collect();
let builder = ArrayDataBuilder::new(field.data_type.clone())
.len(rows.len())
.null_count(null_count)
Expand Down Expand Up @@ -1712,11 +1687,8 @@ mod tests {
let back = converter.convert_rows(&r2).unwrap();
assert_eq!(back.len(), 1);
assert_eq!(&back[0], &s2);
let back_s = as_struct_array(&back[0]);
for c in back_s.columns() {
// Children should not contain nulls
assert_eq!(c.null_count(), 0);
}

back[0].data().validate_full().unwrap();
}

#[test]
Expand Down Expand Up @@ -2198,6 +2170,7 @@ mod tests {

let back = converter.convert_rows(&rows).unwrap();
for (actual, expected) in back.iter().zip(&arrays) {
actual.data().validate_full().unwrap();
assert_eq!(actual, expected)
}
}
Expand Down

0 comments on commit b155461

Please sign in to comment.