From b7248497a43992a6f8da41b25829766b0867891c Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 20 Nov 2023 17:20:10 +0100 Subject: [PATCH] Improve C Data Interface and Add Integration Testing Entrypoints (#5080) * Add C Data Interface integration testing entrypoints * Allow importing FFI_ArrowArray with existing datatype * Clippy * Use ptr::write * Fix null_count for Null type * Use new from_raw() APIs * Address some review comments. * Add unsafe markers * Try to fix CI * Revamp ArrowFile --- arrow-data/src/ffi.rs | 8 +- arrow-integration-testing/Cargo.toml | 5 +- arrow-integration-testing/README.md | 2 +- .../src/bin/arrow-json-integration-test.rs | 49 +--- .../integration_test.rs | 17 +- arrow-integration-testing/src/lib.rs | 228 ++++++++++++++++-- arrow-schema/src/error.rs | 6 + arrow-schema/src/ffi.rs | 10 +- arrow/src/array/ffi.rs | 2 +- arrow/src/ffi.rs | 161 ++++++++----- arrow/src/ffi_stream.rs | 6 +- arrow/src/pyarrow.rs | 6 +- 12 files changed, 363 insertions(+), 137 deletions(-) diff --git a/arrow-data/src/ffi.rs b/arrow-data/src/ffi.rs index 2b4d52601286..589f7dac6d19 100644 --- a/arrow-data/src/ffi.rs +++ b/arrow-data/src/ffi.rs @@ -168,6 +168,12 @@ impl FFI_ArrowArray { .collect::>(); let n_children = children.len() as i64; + // As in the IPC format, emit null_count = length for Null type + let null_count = match data.data_type() { + DataType::Null => data.len(), + _ => data.null_count(), + }; + // create the private data owning everything. // any other data must be added here, e.g. via a struct, to track lifetime. let mut private_data = Box::new(ArrayPrivateData { @@ -179,7 +185,7 @@ impl FFI_ArrowArray { Self { length: data.len() as i64, - null_count: data.null_count() as i64, + null_count: null_count as i64, offset: data.offset() as i64, n_buffers, n_children, diff --git a/arrow-integration-testing/Cargo.toml b/arrow-integration-testing/Cargo.toml index 86c2cb27d297..c29860f09d64 100644 --- a/arrow-integration-testing/Cargo.toml +++ b/arrow-integration-testing/Cargo.toml @@ -27,11 +27,14 @@ edition = { workspace = true } publish = false rust-version = { workspace = true } +[lib] +crate-type = ["lib", "cdylib"] + [features] logging = ["tracing-subscriber"] [dependencies] -arrow = { path = "../arrow", default-features = false, features = ["test_utils", "ipc", "ipc_compression", "json"] } +arrow = { path = "../arrow", default-features = false, features = ["test_utils", "ipc", "ipc_compression", "json", "ffi"] } arrow-flight = { path = "../arrow-flight", default-features = false } arrow-buffer = { path = "../arrow-buffer", default-features = false } arrow-integration-test = { path = "../arrow-integration-test", default-features = false } diff --git a/arrow-integration-testing/README.md b/arrow-integration-testing/README.md index e82591e6b139..dcf39c27fbc5 100644 --- a/arrow-integration-testing/README.md +++ b/arrow-integration-testing/README.md @@ -48,7 +48,7 @@ ln -s arrow/rust ```shell cd arrow -pip install -e dev/archery[docker] +pip install -e dev/archery[integration] ``` ### Build the C++ binaries: diff --git a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs index 187d987a5a0a..9f1abb16a668 100644 --- a/arrow-integration-testing/src/bin/arrow-json-integration-test.rs +++ b/arrow-integration-testing/src/bin/arrow-json-integration-test.rs @@ -15,16 +15,13 @@ // specific language governing permissions and limitations // under the License. 
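The `null_count` change in `arrow-data/src/ffi.rs` above makes the C export match the IPC convention for the Null type. A minimal sketch of the resulting consumer-visible behavior, assuming the `arrow` crate with the `ffi` feature enabled (not part of this patch):

```rust
use arrow::array::{Array, NullArray};
use arrow::ffi::to_ffi;

fn main() -> arrow::error::Result<()> {
    // A Null array carries no validity buffer, so its Rust-side null_count()
    // is 0; after the change above, the exported FFI struct reports
    // null_count == length, matching the IPC format.
    let array = NullArray::new(3);
    let (ffi_array, _ffi_schema) = to_ffi(&array.into_data())?;
    assert_eq!(ffi_array.null_count(), 3);
    Ok(())
}
```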
-use arrow::datatypes::{DataType, Field}; -use arrow::datatypes::{Fields, Schema}; use arrow::error::{ArrowError, Result}; use arrow::ipc::reader::FileReader; use arrow::ipc::writer::FileWriter; use arrow_integration_test::*; -use arrow_integration_testing::read_json_file; +use arrow_integration_testing::{canonicalize_schema, open_json_file}; use clap::Parser; use std::fs::File; -use std::sync::Arc; #[derive(clap::ValueEnum, Debug, Clone)] #[clap(rename_all = "SCREAMING_SNAKE_CASE")] @@ -66,12 +63,12 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()> eprintln!("Converting {json_name} to {arrow_name}"); } - let json_file = read_json_file(json_name)?; + let json_file = open_json_file(json_name)?; let arrow_file = File::create(arrow_name)?; let mut writer = FileWriter::try_new(arrow_file, &json_file.schema)?; - for b in json_file.batches { + for b in json_file.read_batches()? { writer.write(&b)?; } @@ -113,49 +110,13 @@ fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> Ok(()) } -fn canonicalize_schema(schema: &Schema) -> Schema { - let fields = schema - .fields() - .iter() - .map(|field| match field.data_type() { - DataType::Map(child_field, sorted) => match child_field.data_type() { - DataType::Struct(fields) if fields.len() == 2 => { - let first_field = fields.get(0).unwrap(); - let key_field = - Arc::new(Field::new("key", first_field.data_type().clone(), false)); - let second_field = fields.get(1).unwrap(); - let value_field = Arc::new(Field::new( - "value", - second_field.data_type().clone(), - second_field.is_nullable(), - )); - - let fields = Fields::from([key_field, value_field]); - let struct_type = DataType::Struct(fields); - let child_field = Field::new("entries", struct_type, false); - - Arc::new(Field::new( - field.name().as_str(), - DataType::Map(Arc::new(child_field), *sorted), - field.is_nullable(), - )) - } - _ => panic!("The child field of Map type should be Struct type with 2 fields."), - }, - _ => field.clone(), - }) - .collect::(); - - Schema::new(fields).with_metadata(schema.metadata().clone()) -} - fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { if verbose { eprintln!("Validating {arrow_name} and {json_name}"); } // open JSON file - let json_file = read_json_file(json_name)?; + let json_file = open_json_file(json_name)?; // open Arrow file let arrow_file = File::open(arrow_name)?; @@ -170,7 +131,7 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> { ))); } - let json_batches = &json_file.batches; + let json_batches = json_file.read_batches()?; // compare number of batches assert!( diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index 81cc4bbe8ed2..c6b5a72ca6e2 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{read_json_file, ArrowFile}; +use crate::open_json_file; use std::collections::HashMap; use arrow::{ @@ -45,23 +45,16 @@ pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { let client = FlightServiceClient::connect(url).await?; - let ArrowFile { - schema, batches, .. 
- } = read_json_file(path)?; + let json_file = open_json_file(path)?; - let schema = Arc::new(schema); + let batches = json_file.read_batches()?; + let schema = Arc::new(json_file.schema); let mut descriptor = FlightDescriptor::default(); descriptor.set_type(DescriptorType::Path); descriptor.path = vec![path.to_string()]; - upload_data( - client.clone(), - schema.clone(), - descriptor.clone(), - batches.clone(), - ) - .await?; + upload_data(client.clone(), schema, descriptor.clone(), batches.clone()).await?; verify_data(client, descriptor, &batches).await?; Ok(()) diff --git a/arrow-integration-testing/src/lib.rs b/arrow-integration-testing/src/lib.rs index 2d76be3495c8..553e69b0a1a0 100644 --- a/arrow-integration-testing/src/lib.rs +++ b/arrow-integration-testing/src/lib.rs @@ -19,14 +19,20 @@ use serde_json::Value; -use arrow::datatypes::Schema; -use arrow::error::Result; +use arrow::array::{Array, StructArray}; +use arrow::datatypes::{DataType, Field, Fields, Schema}; +use arrow::error::{ArrowError, Result}; +use arrow::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; use arrow::record_batch::RecordBatch; use arrow::util::test_util::arrow_test_data; use arrow_integration_test::*; use std::collections::HashMap; +use std::ffi::{c_int, CStr, CString}; use std::fs::File; use std::io::BufReader; +use std::iter::zip; +use std::ptr; +use std::sync::Arc; /// The expected username for the basic auth integration test. pub const AUTH_USERNAME: &str = "arrow"; @@ -40,11 +46,68 @@ pub struct ArrowFile { pub schema: Schema, // we can evolve this into a concrete Arrow type // this is temporarily not being read from - pub _dictionaries: HashMap, - pub batches: Vec, + dictionaries: HashMap, + arrow_json: Value, } -pub fn read_json_file(json_name: &str) -> Result { +impl ArrowFile { + pub fn read_batch(&self, batch_num: usize) -> Result { + let b = self.arrow_json["batches"].get(batch_num).unwrap(); + let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); + record_batch_from_json(&self.schema, json_batch, Some(&self.dictionaries)) + } + + pub fn read_batches(&self) -> Result> { + self.arrow_json["batches"] + .as_array() + .unwrap() + .iter() + .map(|b| { + let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); + record_batch_from_json(&self.schema, json_batch, Some(&self.dictionaries)) + }) + .collect() + } +} + +// Canonicalize the names of map fields in a schema +pub fn canonicalize_schema(schema: &Schema) -> Schema { + let fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Map(child_field, sorted) => match child_field.data_type() { + DataType::Struct(fields) if fields.len() == 2 => { + let first_field = fields.get(0).unwrap(); + let key_field = + Arc::new(Field::new("key", first_field.data_type().clone(), false)); + let second_field = fields.get(1).unwrap(); + let value_field = Arc::new(Field::new( + "value", + second_field.data_type().clone(), + second_field.is_nullable(), + )); + + let fields = Fields::from([key_field, value_field]); + let struct_type = DataType::Struct(fields); + let child_field = Field::new("entries", struct_type, false); + + Arc::new(Field::new( + field.name().as_str(), + DataType::Map(Arc::new(child_field), *sorted), + field.is_nullable(), + )) + } + _ => panic!("The child field of Map type should be Struct type with 2 fields."), + }, + _ => field.clone(), + }) + .collect::(); + + Schema::new(fields).with_metadata(schema.metadata().clone()) +} + +pub fn open_json_file(json_name: 
&str) -> Result { let json_file = File::open(json_name)?; let reader = BufReader::new(json_file); let arrow_json: Value = serde_json::from_reader(reader).unwrap(); @@ -62,17 +125,10 @@ pub fn read_json_file(json_name: &str) -> Result { dictionaries.insert(json_dict.id, json_dict); } } - - let mut batches = vec![]; - for b in arrow_json["batches"].as_array().unwrap() { - let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap(); - let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?; - batches.push(batch); - } Ok(ArrowFile { schema, - _dictionaries: dictionaries, - batches, + dictionaries, + arrow_json, }) } @@ -100,3 +156,147 @@ pub fn read_gzip_json(version: &str, path: &str) -> ArrowJson { let arrow_json: ArrowJson = serde_json::from_str(&s).unwrap(); arrow_json } + +// +// C Data Integration entrypoints +// + +fn cdata_integration_export_schema_from_json( + c_json_name: *const i8, + out: *mut FFI_ArrowSchema, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let f = open_json_file(json_name.to_str()?)?; + let c_schema = FFI_ArrowSchema::try_from(&f.schema)?; + // Move exported schema into output struct + unsafe { ptr::write(out, c_schema) }; + Ok(()) +} + +fn cdata_integration_export_batch_from_json( + c_json_name: *const i8, + batch_num: c_int, + out: *mut FFI_ArrowArray, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let b = open_json_file(json_name.to_str()?)?.read_batch(batch_num.try_into().unwrap())?; + let a = StructArray::from(b).into_data(); + let c_array = FFI_ArrowArray::new(&a); + // Move exported array into output struct + unsafe { ptr::write(out, c_array) }; + Ok(()) +} + +fn cdata_integration_import_schema_and_compare_to_json( + c_json_name: *const i8, + c_schema: *mut FFI_ArrowSchema, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let json_schema = open_json_file(json_name.to_str()?)?.schema; + + // The source ArrowSchema will be released when this is dropped + let imported_schema = unsafe { FFI_ArrowSchema::from_raw(c_schema) }; + let imported_schema = Schema::try_from(&imported_schema)?; + + // compare schemas + if canonicalize_schema(&json_schema) != canonicalize_schema(&imported_schema) { + return Err(ArrowError::ComputeError(format!( + "Schemas do not match.\n- JSON: {:?}\n- Imported: {:?}", + json_schema, imported_schema + ))); + } + Ok(()) +} + +fn compare_batches(a: &RecordBatch, b: &RecordBatch) -> Result<()> { + if a.num_columns() != b.num_columns() { + return Err(ArrowError::InvalidArgumentError( + "batches do not have the same number of columns".to_string(), + )); + } + for (a_column, b_column) in zip(a.columns(), b.columns()) { + if a_column != b_column { + return Err(ArrowError::InvalidArgumentError( + "batch columns are not the same".to_string(), + )); + } + } + Ok(()) +} + +fn cdata_integration_import_batch_and_compare_to_json( + c_json_name: *const i8, + batch_num: c_int, + c_array: *mut FFI_ArrowArray, +) -> Result<()> { + let json_name = unsafe { CStr::from_ptr(c_json_name) }; + let json_batch = + open_json_file(json_name.to_str()?)?.read_batch(batch_num.try_into().unwrap())?; + let schema = json_batch.schema(); + + let data_type_for_import = DataType::Struct(schema.fields.clone()); + let imported_array = unsafe { FFI_ArrowArray::from_raw(c_array) }; + let imported_array = unsafe { from_ffi_and_data_type(imported_array, data_type_for_import) }?; + imported_array.validate_full()?; + let imported_batch = 
RecordBatch::from(StructArray::from(imported_array)); + + compare_batches(&json_batch, &imported_batch) +} + +// If Result is an error, then export a const char* to its string display, otherwise NULL +fn result_to_c_error(result: &std::result::Result) -> *mut i8 { + match result { + Ok(_) => ptr::null_mut(), + Err(e) => CString::new(format!("{}", e)).unwrap().into_raw(), + } +} + +/// Release a const char* exported by result_to_c_error() +/// +/// # Safety +/// +/// The pointer is assumed to have been obtained using CString::into_raw. +#[no_mangle] +pub unsafe extern "C" fn arrow_rs_free_error(c_error: *mut i8) { + if !c_error.is_null() { + drop(unsafe { CString::from_raw(c_error) }); + } +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_export_schema_from_json( + c_json_name: *const i8, + out: *mut FFI_ArrowSchema, +) -> *mut i8 { + let r = cdata_integration_export_schema_from_json(c_json_name, out); + result_to_c_error(&r) +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_import_schema_and_compare_to_json( + c_json_name: *const i8, + c_schema: *mut FFI_ArrowSchema, +) -> *mut i8 { + let r = cdata_integration_import_schema_and_compare_to_json(c_json_name, c_schema); + result_to_c_error(&r) +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_export_batch_from_json( + c_json_name: *const i8, + batch_num: c_int, + out: *mut FFI_ArrowArray, +) -> *mut i8 { + let r = cdata_integration_export_batch_from_json(c_json_name, batch_num, out); + result_to_c_error(&r) +} + +#[no_mangle] +pub extern "C" fn arrow_rs_cdata_integration_import_batch_and_compare_to_json( + c_json_name: *const i8, + batch_num: c_int, + c_array: *mut FFI_ArrowArray, +) -> *mut i8 { + let r = cdata_integration_import_batch_and_compare_to_json(c_json_name, batch_num, c_array); + result_to_c_error(&r) +} diff --git a/arrow-schema/src/error.rs b/arrow-schema/src/error.rs index 8ea533db89af..b7bf8d6e12a6 100644 --- a/arrow-schema/src/error.rs +++ b/arrow-schema/src/error.rs @@ -58,6 +58,12 @@ impl From for ArrowError { } } +impl From for ArrowError { + fn from(error: std::str::Utf8Error) -> Self { + ArrowError::ParseError(error.to_string()) + } +} + impl From for ArrowError { fn from(error: std::string::FromUtf8Error) -> Self { ArrowError::ParseError(error.to_string()) diff --git a/arrow-schema/src/ffi.rs b/arrow-schema/src/ffi.rs index b4d10b814a5d..8a18c77ea291 100644 --- a/arrow-schema/src/ffi.rs +++ b/arrow-schema/src/ffi.rs @@ -34,7 +34,9 @@ //! assert_eq!(schema, back); //! 
``` -use crate::{ArrowError, DataType, Field, FieldRef, Schema, TimeUnit, UnionFields, UnionMode}; +use crate::{ + ArrowError, DataType, Field, FieldRef, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, +}; use std::sync::Arc; use std::{ collections::HashMap, @@ -402,6 +404,9 @@ impl TryFrom<&FFI_ArrowSchema> for DataType { "tDm" => DataType::Duration(TimeUnit::Millisecond), "tDu" => DataType::Duration(TimeUnit::Microsecond), "tDn" => DataType::Duration(TimeUnit::Nanosecond), + "tiM" => DataType::Interval(IntervalUnit::YearMonth), + "tiD" => DataType::Interval(IntervalUnit::DayTime), + "tin" => DataType::Interval(IntervalUnit::MonthDayNano), "+l" => { let c_child = c_schema.child(0); DataType::List(Arc::new(Field::try_from(c_child)?)) @@ -669,6 +674,9 @@ fn get_format_string(dtype: &DataType) -> Result { DataType::Duration(TimeUnit::Millisecond) => Ok("tDm".to_string()), DataType::Duration(TimeUnit::Microsecond) => Ok("tDu".to_string()), DataType::Duration(TimeUnit::Nanosecond) => Ok("tDn".to_string()), + DataType::Interval(IntervalUnit::YearMonth) => Ok("tiM".to_string()), + DataType::Interval(IntervalUnit::DayTime) => Ok("tiD".to_string()), + DataType::Interval(IntervalUnit::MonthDayNano) => Ok("tin".to_string()), DataType::List(_) => Ok("+l".to_string()), DataType::LargeList(_) => Ok("+L".to_string()), DataType::Struct(_) => Ok("+s".to_string()), diff --git a/arrow/src/array/ffi.rs b/arrow/src/array/ffi.rs index e05c256d0128..d4d95a6e1770 100644 --- a/arrow/src/array/ffi.rs +++ b/arrow/src/array/ffi.rs @@ -70,7 +70,7 @@ mod tests { let schema = FFI_ArrowSchema::try_from(expected.data_type())?; // simulate an external consumer by being the consumer - let result = &from_ffi(array, &schema)?; + let result = &unsafe { from_ffi(array, &schema) }?; assert_eq!(result, expected); Ok(()) diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index c13d4c6e5dff..31388bf99358 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -43,7 +43,7 @@ //! let (out_array, out_schema) = to_ffi(&data)?; //! //! // import it -//! let data = from_ffi(out_array, &out_schema)?; +//! let data = unsafe { from_ffi(out_array, &out_schema) }?; //! let array = Int32Array::from(data); //! //! // perform some operation @@ -80,7 +80,7 @@ //! let mut schema = FFI_ArrowSchema::empty(); //! let mut array = FFI_ArrowArray::empty(); //! foreign.export_to_c(addr_of_mut!(array), addr_of_mut!(schema)); -//! Ok(make_array(from_ffi(array, &schema)?)) +//! Ok(make_array(unsafe { from_ffi(array, &schema) }?)) //! } //! ``` @@ -108,6 +108,7 @@ use std::{mem::size_of, ptr::NonNull, sync::Arc}; pub use arrow_data::ffi::FFI_ArrowArray; pub use arrow_schema::ffi::{FFI_ArrowSchema, Flags}; + use arrow_schema::UnionMode; use crate::array::{layout, ArrayData}; @@ -233,32 +234,53 @@ pub fn to_ffi(data: &ArrayData) -> Result<(FFI_ArrowArray, FFI_ArrowSchema)> { /// # Safety /// /// This struct assumes that the incoming data agrees with the C data interface. -pub fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result { +pub unsafe fn from_ffi(array: FFI_ArrowArray, schema: &FFI_ArrowSchema) -> Result { + let dt = DataType::try_from(schema)?; let array = Arc::new(array); - let tmp = ArrowArray { + let tmp = ImportedArrowArray { array: &array, - schema, + data_type: dt, + owner: &array, + }; + tmp.consume() +} + +/// Import [ArrayData] from the C Data Interface +/// +/// # Safety +/// +/// This struct assumes that the incoming data agrees with the C data interface. 
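+///
+/// # Example (sketch)
+///
+/// A hypothetical caller that already knows the imported type out-of-band;
+/// `ffi_array` stands in for an `FFI_ArrowArray` obtained from a producer:
+///
+/// ```ignore
+/// let data = unsafe { from_ffi_and_data_type(ffi_array, DataType::Int32) }?;
+/// ```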
+pub unsafe fn from_ffi_and_data_type( + array: FFI_ArrowArray, + data_type: DataType, +) -> Result { + let array = Arc::new(array); + let tmp = ImportedArrowArray { + array: &array, + data_type, owner: &array, }; tmp.consume() } #[derive(Debug)] -struct ArrowArray<'a> { +struct ImportedArrowArray<'a> { array: &'a FFI_ArrowArray, - schema: &'a FFI_ArrowSchema, + data_type: DataType, owner: &'a Arc, } -impl<'a> ArrowArray<'a> { +impl<'a> ImportedArrowArray<'a> { fn consume(self) -> Result { - let dt = DataType::try_from(self.schema)?; let len = self.array.len(); let offset = self.array.offset(); - let null_count = self.array.null_count(); + let null_count = match &self.data_type { + DataType::Null => 0, + _ => self.array.null_count(), + }; - let data_layout = layout(&dt); - let buffers = self.buffers(data_layout.can_contain_null_mask, &dt)?; + let data_layout = layout(&self.data_type); + let buffers = self.buffers(data_layout.can_contain_null_mask)?; let null_bit_buffer = if data_layout.can_contain_null_mask { self.null_bit_buffer() @@ -266,14 +288,9 @@ impl<'a> ArrowArray<'a> { None }; - let mut child_data = (0..self.array.num_children()) - .map(|i| { - let child = self.child(i); - child.consume() - }) - .collect::>>()?; + let mut child_data = self.consume_children()?; - if let Some(d) = self.dictionary() { + if let Some(d) = self.dictionary()? { // For dictionary type there should only be a single child, so we don't need to worry if // there are other children added above. assert!(child_data.is_empty()); @@ -283,7 +300,7 @@ impl<'a> ArrowArray<'a> { // Should FFI be checking validity? Ok(unsafe { ArrayData::new_unchecked( - dt, + self.data_type, len, Some(null_count), null_bit_buffer, @@ -294,14 +311,49 @@ impl<'a> ArrowArray<'a> { }) } + fn consume_children(&self) -> Result> { + match &self.data_type { + DataType::List(field) + | DataType::FixedSizeList(field, _) + | DataType::LargeList(field) + | DataType::Map(field, _) => Ok([self.consume_child(0, field.data_type())?].to_vec()), + DataType::Struct(fields) => { + assert!(fields.len() == self.array.num_children()); + fields + .iter() + .enumerate() + .map(|(i, field)| self.consume_child(i, field.data_type())) + .collect::>>() + } + DataType::Union(union_fields, _) => { + assert!(union_fields.len() == self.array.num_children()); + union_fields + .iter() + .enumerate() + .map(|(i, (_, field))| self.consume_child(i, field.data_type())) + .collect::>>() + } + _ => Ok(Vec::new()), + } + } + + fn consume_child(&self, index: usize, child_type: &DataType) -> Result { + ImportedArrowArray { + array: self.array.child(index), + data_type: child_type.clone(), + owner: self.owner, + } + .consume() + } + /// returns all buffers, as organized by Rust (i.e. 
null buffer is skipped if it's present /// in the spec of the type) - fn buffers(&self, can_contain_null_mask: bool, dt: &DataType) -> Result> { + fn buffers(&self, can_contain_null_mask: bool) -> Result> { // + 1: skip null buffer let buffer_begin = can_contain_null_mask as usize; (buffer_begin..self.array.num_buffers()) .map(|index| { - let len = self.buffer_len(index, dt)?; + let len = self.buffer_len(index, &self.data_type)?; match unsafe { create_buffer(self.owner.clone(), self.array, index, len) } { Some(buf) => Ok(buf), @@ -388,25 +440,20 @@ impl<'a> ArrowArray<'a> { unsafe { create_buffer(self.owner.clone(), self.array, 0, buffer_len) } } - fn child(&self, index: usize) -> ArrowArray { - ArrowArray { - array: self.array.child(index), - schema: self.schema.child(index), - owner: self.owner, - } - } - - fn dictionary(&self) -> Option { - match (self.array.dictionary(), self.schema.dictionary()) { - (Some(array), Some(schema)) => Some(ArrowArray { + fn dictionary(&self) -> Result> { + match (self.array.dictionary(), &self.data_type) { + (Some(array), DataType::Dictionary(_, value_type)) => Ok(Some(ImportedArrowArray { array, - schema, + data_type: value_type.as_ref().clone(), owner: self.owner, - }), - (None, None) => None, - _ => panic!( - "Dictionary should both be set or not set in FFI_ArrowArray and FFI_ArrowSchema" - ), + })), + (Some(_), _) => Err(ArrowError::CDataInterface( + "Got dictionary in FFI_ArrowArray for non-dictionary data type".to_string(), + )), + (None, DataType::Dictionary(_, _)) => Err(ArrowError::CDataInterface( + "Missing dictionary in FFI_ArrowArray for dictionary data type".to_string(), + )), + (_, _) => Ok(None), } } } @@ -443,7 +490,7 @@ mod tests { let (array, schema) = to_ffi(&array.into_data()).unwrap(); // (simulate consumer) import it - let array = Int32Array::from(from_ffi(array, &schema).unwrap()); + let array = Int32Array::from(unsafe { from_ffi(array, &schema) }.unwrap()); let array = kernels::numeric::add(&array, &array).unwrap(); // verify @@ -487,7 +534,7 @@ mod tests { let (array, schema) = to_ffi(&array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -517,7 +564,7 @@ mod tests { let (array, schema) = to_ffi(&original_array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -539,7 +586,7 @@ mod tests { let (array, schema) = to_ffi(&array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -608,7 +655,7 @@ mod tests { let (array, schema) = to_ffi(&array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // downcast @@ -648,7 +695,7 @@ mod tests { let (array, schema) = to_ffi(&array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -693,7 +740,7 @@ mod tests { let (array, schema) = to_ffi(&array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); 
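         // The unsafe block above is the import-side contract this patch
         // introduces: the caller, not the library, vouches that the FFI
         // structs follow the C data interface, which is why from_ffi is
         // now marked unsafe.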
// perform some operation @@ -719,7 +766,7 @@ mod tests { let (array, schema) = to_ffi(&array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -748,7 +795,7 @@ mod tests { let (array, schema) = to_ffi(&array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -784,7 +831,7 @@ mod tests { let (array, schema) = to_ffi(&array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -845,7 +892,7 @@ mod tests { let (array, schema) = to_ffi(&list_data)?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -890,7 +937,7 @@ mod tests { let (array, schema) = to_ffi(&dict_array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -928,7 +975,7 @@ mod tests { } // (simulate consumer) import it - let data = from_ffi(out_array, &out_schema)?; + let data = unsafe { from_ffi(out_array, &out_schema) }?; let array = make_array(data); // perform some operation @@ -949,7 +996,7 @@ mod tests { let (array, schema) = to_ffi(&array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -986,7 +1033,7 @@ mod tests { let (array, schema) = to_ffi(&map_array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -1009,7 +1056,7 @@ mod tests { let (array, schema) = to_ffi(&struct_array.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); // perform some operation @@ -1033,7 +1080,7 @@ mod tests { let (array, schema) = to_ffi(&union.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = make_array(data); let array = array.as_any().downcast_ref::().unwrap(); @@ -1094,7 +1141,7 @@ mod tests { let (array, schema) = to_ffi(&union.to_data())?; // (simulate consumer) import it - let data = from_ffi(array, &schema)?; + let data = unsafe { from_ffi(array, &schema) }?; let array = UnionArray::from(data); let expected_type_ids = vec![0_i8, 0, 1, 0]; diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs index 123669aa61be..bbec71e8837e 100644 --- a/arrow/src/ffi_stream.rs +++ b/arrow/src/ffi_stream.rs @@ -357,9 +357,11 @@ impl Iterator for ArrowArrayStreamReader { } let schema_ref = self.schema(); + // NOTE: this parses the FFI_ArrowSchema again on each iterator call; + // should probably use from_ffi_and_data_type() instead. 
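+        // A possible shape for that follow-up (untested sketch):
+        //
+        //   let dt = DataType::Struct(schema_ref.fields().clone());
+        //   let data = unsafe { from_ffi_and_data_type(array, dt) }.ok()?;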
         let schema = FFI_ArrowSchema::try_from(schema_ref.as_ref()).ok()?;

-        let data = from_ffi(array, &schema).ok()?;
+        let data = unsafe { from_ffi(array, &schema) }.ok()?;

         let record_batch = RecordBatch::from(StructArray::from(data));

@@ -464,7 +466,7 @@ mod tests {
                 break;
             }

-            let array = from_ffi(ffi_array, &ffi_schema).unwrap();
+            let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
             let record_batch = RecordBatch::from(StructArray::from(array));
             produced_batches.push(record_batch);

diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index 2ac550ad0456..8302f8741b60 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -267,7 +267,7 @@ impl FromPyArrow for ArrayData {
         let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
         let array = unsafe { FFI_ArrowArray::from_raw(array_capsule.pointer() as _) };

-        return ffi::from_ffi(array, schema_ptr).map_err(to_py_err);
+        return unsafe { ffi::from_ffi(array, schema_ptr) }.map_err(to_py_err);
     }

     validate_class("Array", value)?;
@@ -287,7 +287,7 @@ impl FromPyArrow for ArrayData {
             ),
         )?;

-        ffi::from_ffi(array, &schema).map_err(to_py_err)
+        unsafe { ffi::from_ffi(array, &schema) }.map_err(to_py_err)
     }
 }

@@ -348,7 +348,7 @@ impl FromPyArrow for RecordBatch {
         let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
         let ffi_array = unsafe { FFI_ArrowArray::from_raw(array_capsule.pointer() as _) };

-        let array_data = ffi::from_ffi(ffi_array, schema_ptr).map_err(to_py_err)?;
+        let array_data = unsafe { ffi::from_ffi(ffi_array, schema_ptr) }.map_err(to_py_err)?;
         if !matches!(array_data.data_type(), DataType::Struct(_)) {
             return Err(PyTypeError::new_err(
                 "Expected Struct type from __arrow_c_array.",