Skip to content

Commit

Permalink
Loosen nullability restrictions added in apache#3205 (apache#3226)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Nov 30, 2022
1 parent bdfe0fd commit 2caa0a4
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 64 deletions.
2 changes: 1 addition & 1 deletion arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ mod tests {
#[test]
fn test_null_struct() {
let struct_type =
DataType::Struct(vec![Field::new("data", DataType::Int64, true)]);
DataType::Struct(vec![Field::new("data", DataType::Int64, false)]);
let array = new_null_array(&struct_type, 9);

let a = array.as_any().downcast_ref::<StructArray>().unwrap();
Expand Down
22 changes: 8 additions & 14 deletions arrow-array/src/array/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,20 +227,17 @@ impl From<Vec<(Field, ArrayRef)>> for StructArray {
field_value.data().data_type(),
"the field data types must match the array data in a StructArray"
);
// Check nullability of child arrays
if !field_type.is_nullable() {
assert!(
field_value.null_count() == 0,
"non-nullable field cannot have null values"
);
}
},
);

let array_data = ArrayData::builder(DataType::Struct(field_types))
.child_data(field_values.into_iter().map(|a| a.into_data()).collect())
.len(length);
let array_data = unsafe { array_data.build_unchecked() };

// We must validate nullability
array_data.validate_nulls().unwrap();

Self::from(array_data)
}
}
Expand Down Expand Up @@ -283,13 +280,6 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray {
field_value.data().data_type(),
"the field data types must match the array data in a StructArray"
);
// Check nullability of child arrays
if !field_type.is_nullable() {
assert!(
field_value.null_count() == 0,
"non-nullable field cannot have null values"
);
}
},
);

Expand All @@ -298,6 +288,10 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray {
.child_data(field_values.into_iter().map(|a| a.into_data()).collect())
.len(length);
let array_data = unsafe { array_data.build_unchecked() };

// We must validate nullability
array_data.validate_nulls().unwrap();

Self::from(array_data)
}
}
Expand Down
2 changes: 1 addition & 1 deletion arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6714,7 +6714,7 @@ mod tests {
cast_from_null_to_other(&data_type);

// Cast null from and to struct
let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, true)]);
let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, false)]);
cast_from_null_to_other(&data_type);
}

Expand Down
139 changes: 122 additions & 17 deletions arrow-data/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! common attributes and operations for Arrow array.
use crate::{bit_iterator::BitSliceIterator, bitmap::Bitmap};
use arrow_buffer::bit_chunk_iterator::BitChunks;
use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer};
use arrow_schema::{ArrowError, DataType, IntervalUnit, UnionMode};
use half::f16;
Expand Down Expand Up @@ -618,7 +619,7 @@ impl ArrayData {
/// are within the bounds of the values buffer).
///
/// See [ArrayData::validate_full] to validate fully the offset content
/// and the validitiy of utf8 data
/// and the validity of utf8 data
pub fn validate(&self) -> Result<(), ArrowError> {
// Need at least this much space in each buffer
let len_plus_offset = self.len + self.offset;
Expand Down Expand Up @@ -961,26 +962,19 @@ impl ArrayData {
/// 3. All String data is valid UTF-8
/// 4. All dictionary offsets are valid
///
/// Does not (yet) check
/// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85)
/// Note calls `validate()` internally
/// Internally this calls:
///
/// * [`Self::validate`]
/// * [`Self::validate_nulls`]
/// * [`Self::validate_values`]
///
/// And then for each child [`ArrayData`] calls [`ArrayData::validate_full`]
///
pub fn validate_full(&self) -> Result<(), ArrowError> {
// Check all buffer sizes prior to looking at them more deeply in this function
self.validate()?;

let null_bitmap_buffer = self
.null_bitmap
.as_ref()
.map(|null_bitmap| null_bitmap.buffer_ref());

let actual_null_count = count_nulls(null_bitmap_buffer, self.offset, self.len);
if actual_null_count != self.null_count {
return Err(ArrowError::InvalidArgumentError(format!(
"null_count value ({}) doesn't match actual number of nulls in array ({})",
self.null_count, actual_null_count
)));
}

self.validate_nulls()?;
self.validate_values()?;

// validate all children recursively
Expand All @@ -999,6 +993,117 @@ impl ArrayData {
Ok(())
}

/// Validates that the null count is correct and that any
/// nullability requirements of its children are satisfied
///
/// # Errors
///
/// Returns [`ArrowError::InvalidArgumentError`] when:
///
/// * `null_count` differs from the number of nulls counted in the null buffer
/// * a child whose field is declared non-nullable contains a null in a
///   position where this array is not itself null
/// * a list-like type with a non-nullable field has no child data to inspect
///
/// Only the immediate children are inspected; this method does not recurse.
pub fn validate_nulls(&self) -> Result<(), ArrowError> {
    let nulls = self.null_buffer();

    let actual_null_count = count_nulls(nulls, self.offset, self.len);
    if actual_null_count != self.null_count {
        return Err(ArrowError::InvalidArgumentError(format!(
            "null_count value ({}) doesn't match actual number of nulls in array ({})",
            self.null_count, actual_null_count
        )));
    }

    // This method is public and can be called on data that has not been through
    // `validate`, so a missing child is reported as an error instead of panicking
    // on an out-of-bounds index.
    let single_child = || {
        self.child_data.first().ok_or_else(|| {
            ArrowError::InvalidArgumentError(format!(
                "{} must have a child array in order to validate nulls, but has none",
                self.data_type
            ))
        })
    };

    // In general non-nullable children should not contain nulls, however, for certain
    // types, such as StructArray and FixedSizeList, nulls in the parent take up
    // space in the child. As such we permit nulls in the children in the corresponding
    // positions for such types
    match &self.data_type {
        DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => {
            if !f.is_nullable() {
                // No mask is passed: the child must not contain any nulls at all
                self.validate_non_nullable(None, 0, single_child()?)?
            }
        }
        DataType::FixedSizeList(field, len) => {
            // The child is only needed (and only looked up) for a non-nullable field
            if !field.is_nullable() {
                let child = single_child()?;
                match nulls {
                    Some(nulls) => {
                        // NOTE(review): assumes `len` is non-negative, presumably
                        // enforced by `validate` - confirm before relying on it
                        let element_len = *len as usize;

                        // Expand the parent null buffer to one bit per child slot:
                        // each parent slot covers `element_len` consecutive child
                        // slots, and a child bit is set only when its parent bit is
                        let mut buffer =
                            MutableBuffer::new_null(element_len * self.len);

                        for i in 0..self.len {
                            if !bit_util::get_bit(nulls.as_ref(), self.offset + i) {
                                continue;
                            }
                            for j in 0..element_len {
                                bit_util::set_bit(
                                    buffer.as_mut(),
                                    i * element_len + j,
                                )
                            }
                        }
                        let mask = buffer.into();
                        // The expanded mask already accounts for `self.offset`
                        self.validate_non_nullable(Some(&mask), 0, child)?;
                    }
                    None => self.validate_non_nullable(None, 0, child)?,
                }
            }
        }
        DataType::Struct(fields) => {
            // Struct children line up with the parent slot-for-slot, so the parent
            // null buffer is used directly as the mask
            for (field, child) in fields.iter().zip(&self.child_data) {
                if !field.is_nullable() {
                    self.validate_non_nullable(nulls, self.offset, child)?
                }
            }
        }
        _ => {}
    }

    Ok(())
}

/// Verifies that `data` contains no nulls other than those permitted by `mask`
///
/// * `mask` - bitmap read starting at bit `offset`; a null in `data` is only
///   tolerated at positions where the corresponding `mask` bit is unset. When
///   `None`, `data` must report a null count of zero
/// * `offset` - bit offset into `mask` of the first position to compare
/// * `data` - the child whose field is declared non-nullable
///
/// Returns [`ArrowError::InvalidArgumentError`] if a disallowed null is found.
fn validate_non_nullable(
    &self,
    mask: Option<&Buffer>,
    offset: usize,
    data: &ArrayData,
) -> Result<(), ArrowError> {
    let parent_bits: &[u8] = match mask {
        Some(buffer) => buffer.as_ref(),
        None => {
            // Without a mask no null is excusable, so the null count decides
            if data.null_count == 0 {
                return Ok(());
            }
            return Err(ArrowError::InvalidArgumentError(format!(
                "non-nullable child of type {} contains nulls not present in parent {}",
                data.data_type(),
                self.data_type
            )));
        }
    };

    // A child without a null buffer has nothing to compare against the mask
    let child_nulls = match data.null_buffer() {
        Some(buffer) => buffer,
        None => return Ok(()),
    };

    let parent_chunks = BitChunks::new(parent_bits, offset, data.len);
    let child_chunks = BitChunks::new(child_nulls.as_ref(), data.offset, data.len);

    // Compare the full chunks first, then the trailing remainder bits
    let full_chunks = parent_chunks.iter().zip(child_chunks.iter());
    let remainder = (
        parent_chunks.remainder_bits(),
        child_chunks.remainder_bits(),
    );

    for (parent_word, child_word) in full_chunks.chain(std::iter::once(remainder)) {
        // A bit set in the mask but unset in the child is a disallowed null
        if parent_word & !child_word != 0 {
            return Err(ArrowError::InvalidArgumentError(format!(
                "non-nullable child of type {} contains nulls not present in parent",
                data.data_type()
            )));
        }
    }

    Ok(())
}

/// Validates the values stored within this [`ArrayData`] are valid
/// without recursing into child [`ArrayData`]
///
/// Does not (yet) check
/// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85)
pub fn validate_values(&self) -> Result<(), ArrowError> {
match &self.data_type {
DataType::Utf8 => self.validate_utf8::<i32>(),
Expand Down
35 changes: 4 additions & 31 deletions arrow/src/row/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1127,36 +1127,11 @@ unsafe fn decode_column(
}
}
Codec::Struct(converter, _) => {
let child_fields = match &field.data_type {
DataType::Struct(f) => f,
_ => unreachable!(),
};

let (null_count, nulls) = fixed::decode_nulls(rows);
rows.iter_mut().for_each(|row| *row = &row[1..]);
let children = converter.convert_raw(rows, validate_utf8)?;

let child_data = child_fields
.iter()
.zip(&children)
.map(|(f, c)| {
let data = c.data().clone();
match f.is_nullable() {
true => data,
false => {
assert_eq!(data.null_count(), null_count);
// Need to strip out null buffer if any as this is created
// as an artifact of the row encoding process that encodes
// nulls from the parent struct array in the children
data.into_builder()
.null_count(0)
.null_bit_buffer(None)
.build_unchecked()
}
}
})
.collect();

let child_data = children.iter().map(|c| c.data().clone()).collect();
let builder = ArrayDataBuilder::new(field.data_type.clone())
.len(rows.len())
.null_count(null_count)
Expand Down Expand Up @@ -1585,11 +1560,8 @@ mod tests {
let back = converter.convert_rows(&r2).unwrap();
assert_eq!(back.len(), 1);
assert_eq!(&back[0], &s2);
let back_s = as_struct_array(&back[0]);
for c in back_s.columns() {
// Children should not contain nulls
assert_eq!(c.null_count(), 0);
}

back[0].data().validate_full().unwrap();
}

#[test]
Expand Down Expand Up @@ -1858,6 +1830,7 @@ mod tests {

let back = converter.convert_rows(&rows).unwrap();
for (actual, expected) in back.iter().zip(&arrays) {
actual.data().validate_full().unwrap();
assert_eq!(actual, expected)
}
}
Expand Down

0 comments on commit 2caa0a4

Please sign in to comment.