From 8ab18fd6be20ebb77157839c35a7c359dc25a58e Mon Sep 17 00:00:00 2001 From: Alex Wilcoxson Date: Fri, 20 Sep 2024 05:02:47 -0500 Subject: [PATCH] =?UTF-8?q?fix:=20don't=20panic=20in=20IPC=20reader=20if?= =?UTF-8?q?=20struct=20child=20arrays=20have=20different=20=E2=80=A6=20(#6?= =?UTF-8?q?417)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: don't panic in IPC reader if struct child arrays have different lengths * fix: clippy * test: add ipc read invalid struct test --- arrow-ipc/src/reader.rs | 54 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 2b1d09dc9588..3e07c95afb23 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -149,7 +149,7 @@ fn create_array( // still work for struct_field in struct_fields { let child = create_array(reader, struct_field, variadic_counts, require_alignment)?; - struct_arrays.push((struct_field.clone(), child)); + struct_arrays.push(child); } let null_count = struct_node.null_count() as usize; let struct_array = if struct_arrays.is_empty() { @@ -162,9 +162,11 @@ fn create_array( ) } else if null_count > 0 { // create struct array from fields, arrays and null data - StructArray::from((struct_arrays, null_buffer)) + let len = struct_node.length() as usize; + let nulls = BooleanBuffer::new(null_buffer, 0, len).into(); + StructArray::try_new(struct_fields.clone(), struct_arrays, Some(nulls))? } else { - StructArray::from(struct_arrays) + StructArray::try_new(struct_fields.clone(), struct_arrays, None)? }; Ok(Arc::new(struct_array)) } @@ -2235,4 +2237,50 @@ mod tests { assert_eq!(batch, roundtrip_batch); } + + #[test] + fn test_invalid_struct_array_ipc_read_errors() { + let a_field = Field::new("a", DataType::Int32, false); + let b_field = Field::new("b", DataType::Int32, false); + + let schema = Arc::new(Schema::new(vec![Field::new_struct( + "s", + vec![a_field.clone(), b_field.clone()], + false, + )])); + + let a_array_data = ArrayData::builder(a_field.data_type().clone()) + .len(4) + .add_buffer(Buffer::from_slice_ref([1, 2, 3, 4])) + .build() + .unwrap(); + let b_array_data = ArrayData::builder(b_field.data_type().clone()) + .len(3) + .add_buffer(Buffer::from_slice_ref([5, 6, 7])) + .build() + .unwrap(); + + let struct_data_type = schema.field(0).data_type(); + + let invalid_struct_arr = unsafe { + make_array( + ArrayData::builder(struct_data_type.clone()) + .len(4) + .add_child_data(a_array_data) + .add_child_data(b_array_data) + .build_unchecked(), + ) + }; + + let batch = RecordBatch::try_new(schema.clone(), vec![invalid_struct_arr]).unwrap(); + + let mut buf = Vec::new(); + let mut writer = crate::writer::FileWriter::try_new(&mut buf, schema.as_ref()).unwrap(); + writer.write(&batch).unwrap(); + writer.finish().unwrap(); + + let mut reader = FileReader::try_new(std::io::Cursor::new(buf), None).unwrap(); + let err = reader.next().unwrap().unwrap_err(); + assert!(matches!(err, ArrowError::InvalidArgumentError(_))); + } }