Skip to content

Commit

Permalink
Ensure StructArrays check nullability of fields (#3205)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jefffrey authored Nov 29, 2022
1 parent 4926bad commit 1d6b5ab
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 48 deletions.
2 changes: 1 addition & 1 deletion arrow-array/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ mod tests {
#[test]
fn test_null_struct() {
let struct_type =
DataType::Struct(vec![Field::new("data", DataType::Int64, false)]);
DataType::Struct(vec![Field::new("data", DataType::Int64, true)]);
let array = new_null_array(&struct_type, 9);

let a = array.as_any().downcast_ref::<StructArray>().unwrap();
Expand Down
107 changes: 77 additions & 30 deletions arrow-array/src/array/struct_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl TryFrom<Vec<(&str, ArrayRef)>> for StructArray {
child_null_buffer.bit_slice(child_datum_offset, child_datum_len)
});
} else if null.is_some() {
// when one of the fields has no nulls, them there is no null in the array
// when one of the fields has no nulls, then there is no null in the array
null = None;
}
}
Expand Down Expand Up @@ -212,20 +212,30 @@ impl From<Vec<(Field, ArrayRef)>> for StructArray {
fn from(v: Vec<(Field, ArrayRef)>) -> Self {
let (field_types, field_values): (Vec<_>, Vec<_>) = v.into_iter().unzip();

// Check the length of the child arrays
let length = field_values[0].len();
for i in 1..field_values.len() {
assert_eq!(
length,
field_values[i].len(),
"all child arrays of a StructArray must have the same length"
);
assert_eq!(
field_types[i].data_type(),
field_values[i].data().data_type(),
"the field data types must match the array data in a StructArray"
)
}
let length = field_values.get(0).map(|a| a.len()).unwrap_or(0);
field_types.iter().zip(field_values.iter()).for_each(
|(field_type, field_value)| {
// Check the length of the child arrays
assert_eq!(
length,
field_value.len(),
"all child arrays of a StructArray must have the same length"
);
// Check data types of child arrays
assert_eq!(
field_type.data_type(),
field_value.data().data_type(),
"the field data types must match the array data in a StructArray"
);
// Check nullability of child arrays
if !field_type.is_nullable() {
assert!(
field_value.null_count() == 0,
"non-nullable field cannot have null values"
);
}
},
);

let array_data = ArrayData::builder(DataType::Struct(field_types))
.child_data(field_values.into_iter().map(|a| a.into_data()).collect())
Expand Down Expand Up @@ -258,20 +268,30 @@ impl From<(Vec<(Field, ArrayRef)>, Buffer)> for StructArray {
fn from(pair: (Vec<(Field, ArrayRef)>, Buffer)) -> Self {
let (field_types, field_values): (Vec<_>, Vec<_>) = pair.0.into_iter().unzip();

// Check the length of the child arrays
let length = field_values[0].len();
for i in 1..field_values.len() {
assert_eq!(
length,
field_values[i].len(),
"all child arrays of a StructArray must have the same length"
);
assert_eq!(
field_types[i].data_type(),
field_values[i].data().data_type(),
"the field data types must match the array data in a StructArray"
)
}
let length = field_values.get(0).map(|a| a.len()).unwrap_or(0);
field_types.iter().zip(field_values.iter()).for_each(
|(field_type, field_value)| {
// Check the length of the child arrays
assert_eq!(
length,
field_value.len(),
"all child arrays of a StructArray must have the same length"
);
// Check data types of child arrays
assert_eq!(
field_type.data_type(),
field_value.data().data_type(),
"the field data types must match the array data in a StructArray"
);
// Check nullability of child arrays
if !field_type.is_nullable() {
assert!(
field_value.null_count() == 0,
"non-nullable field cannot have null values"
);
}
},
);

let array_data = ArrayData::builder(DataType::Struct(field_types))
.null_bit_buffer(Some(pair.1))
Expand Down Expand Up @@ -408,7 +428,19 @@ mod tests {
#[should_panic(
expected = "the field data types must match the array data in a StructArray"
)]
fn test_struct_array_from_mismatched_types() {
fn test_struct_array_from_mismatched_types_single() {
drop(StructArray::from(vec![(
Field::new("b", DataType::Int16, false),
Arc::new(BooleanArray::from(vec![false, false, true, true]))
as Arc<dyn Array>,
)]));
}

#[test]
#[should_panic(
expected = "the field data types must match the array data in a StructArray"
)]
fn test_struct_array_from_mismatched_types_multiple() {
drop(StructArray::from(vec![
(
Field::new("b", DataType::Int16, false),
Expand Down Expand Up @@ -528,4 +560,19 @@ mod tests {
),
]));
}

#[test]
fn test_struct_array_from_empty() {
    // Converting an empty list of (field, column) pairs must complete without
    // panicking and yield a struct array that reports itself as empty.
    let empty_struct = StructArray::from(vec![]);
    assert!(empty_struct.is_empty());
}

#[test]
#[should_panic(expected = "non-nullable field cannot have null values")]
fn test_struct_array_from_mismatched_nullability() {
    // Field "c" is declared non-nullable, yet the column paired with it holds
    // a null in its second slot; the conversion is expected to panic with the
    // message named in `should_panic` above.
    let non_nullable_field = Field::new("c", DataType::Int32, false);
    let column_with_null: ArrayRef =
        Arc::new(Int32Array::from(vec![Some(42), None, Some(19)]));
    drop(StructArray::from(vec![(non_nullable_field, column_with_null)]));
}
}
15 changes: 7 additions & 8 deletions arrow-array/src/builder/map_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ use std::sync::Arc;
///
/// let string_builder = builder.keys();
/// string_builder.append_value("joe");
/// string_builder.append_null();
/// string_builder.append_null();
/// string_builder.append_value("n1");
/// string_builder.append_value("n2");
/// string_builder.append_value("mark");
///
/// let int_builder = builder.values();
Expand All @@ -58,7 +58,7 @@ use std::sync::Arc;
/// );
/// assert_eq!(
/// *arr.keys(),
/// StringArray::from(vec![Some("joe"), None, None, Some("mark")])
/// StringArray::from(vec![Some("joe"), Some("n1"), Some("n2"), Some("mark")])
/// );
/// ```
#[derive(Debug)]
Expand Down Expand Up @@ -286,8 +286,8 @@ mod tests {

let string_builder = builder.keys();
string_builder.append_value("joe");
string_builder.append_null();
string_builder.append_null();
string_builder.append_value("n1");
string_builder.append_value("n2");
string_builder.append_value("mark");

let int_builder = builder.values();
Expand All @@ -312,9 +312,8 @@ mod tests {

let expected_string_data = ArrayData::builder(DataType::Utf8)
.len(4)
.null_bit_buffer(Some(Buffer::from(&[9_u8])))
.add_buffer(Buffer::from_slice_ref([0, 3, 3, 3, 7]))
.add_buffer(Buffer::from_slice_ref(b"joemark"))
.add_buffer(Buffer::from_slice_ref([0, 3, 5, 7, 11]))
.add_buffer(Buffer::from_slice_ref(b"joen1n2mark"))
.build()
.unwrap();

Expand Down
3 changes: 1 addition & 2 deletions arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6714,8 +6714,7 @@ mod tests {
cast_from_null_to_other(&data_type);

// Cast null from and to struct
let data_type =
DataType::Struct(vec![Field::new("data", DataType::Int64, false)]);
let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, true)]);
cast_from_null_to_other(&data_type);
}

Expand Down
2 changes: 1 addition & 1 deletion arrow-ipc/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1799,7 +1799,7 @@ mod tests {
Arc::new(strings) as ArrayRef,
),
(
Field::new("c", DataType::Int32, false),
Field::new("c", DataType::Int32, true),
Arc::new(ints) as ArrayRef,
),
]);
Expand Down
8 changes: 4 additions & 4 deletions arrow-json/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ mod tests {
Field::new(
"c1",
DataType::Struct(vec![
Field::new("c11", DataType::Int32, false),
Field::new("c11", DataType::Int32, true),
Field::new(
"c12",
DataType::Struct(vec![Field::new("c121", DataType::Utf8, false)]),
Expand All @@ -1083,7 +1083,7 @@ mod tests {

let c1 = StructArray::from(vec![
(
Field::new("c11", DataType::Int32, false),
Field::new("c11", DataType::Int32, true),
Arc::new(Int32Array::from(vec![Some(1), None, Some(5)])) as ArrayRef,
),
(
Expand Down Expand Up @@ -1230,7 +1230,7 @@ mod tests {
DataType::List(Box::new(Field::new(
"s",
DataType::Struct(vec![
Field::new("c11", DataType::Int32, false),
Field::new("c11", DataType::Int32, true),
Field::new(
"c12",
DataType::Struct(vec![Field::new("c121", DataType::Utf8, false)]),
Expand All @@ -1246,7 +1246,7 @@ mod tests {

let struct_values = StructArray::from(vec![
(
Field::new("c11", DataType::Int32, false),
Field::new("c11", DataType::Int32, true),
Arc::new(Int32Array::from(vec![Some(1), None, Some(5)])) as ArrayRef,
),
(
Expand Down
4 changes: 2 additions & 2 deletions arrow/src/util/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ mod tests {
Field::new(
"c1",
DataType::Struct(vec![
Field::new("c11", DataType::Int32, false),
Field::new("c11", DataType::Int32, true),
Field::new(
"c12",
DataType::Struct(vec![Field::new("c121", DataType::Utf8, false)]),
Expand All @@ -727,7 +727,7 @@ mod tests {

let c1 = StructArray::from(vec![
(
Field::new("c11", DataType::Int32, false),
Field::new("c11", DataType::Int32, true),
Arc::new(Int32Array::from(vec![Some(1), None, Some(5)])) as ArrayRef,
),
(
Expand Down

0 comments on commit 1d6b5ab

Please sign in to comment.