diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index 31846ee1fdc3..cb10e2c53bbb 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -34,7 +34,7 @@ use arrow_array::types::*; use arrow_array::*; use arrow_buffer::ArrowNativeType; use arrow_data::transform::{Capacities, MutableArrayData}; -use arrow_schema::{ArrowError, DataType, SchemaRef}; +use arrow_schema::{ArrowError, DataType, Schema, SchemaRef}; fn binary_capacity(arrays: &[&dyn Array]) -> Capacities { let mut item_capacity = 0; @@ -112,7 +112,7 @@ pub fn concat_batches<'a>( if let Some((i, _)) = batches .iter() .enumerate() - .find(|&(_, batch)| batch.schema() != *schema) + .find(|&(_, batch)| !concatable_schema(schema.as_ref(), batch.schema().as_ref())) { return Err(ArrowError::InvalidArgumentError(format!( "batches[{i}] schema is different with argument schema. @@ -137,12 +137,31 @@ pub fn concat_batches<'a>( RecordBatch::try_new(schema.clone(), arrays) } +/// Returns true if data with the `source` Schema can be placed in a +/// record batch with `target` Schema +fn concatable_schema(target: &Schema, source: &Schema) -> bool { + // ignore metadata + // https://github.com/apache/arrow-rs/issues/4799 + if source.fields().len() != target.fields().len() { + return false; + } + + source.fields().iter().zip(target.fields().iter()).all( + |(source_field, target_field)| { + // also ignore nullability as `RecordBatch::try_new()` + // will validate that + source_field.name() == target_field.name() + && source_field.data_type() == target_field.data_type() + }, + ) +} + #[cfg(test)] mod tests { use super::*; use arrow_array::cast::AsArray; use arrow_schema::{Field, Schema}; - use std::sync::Arc; + use std::{collections::HashMap, sync::Arc}; #[test] fn test_concat_empty_vec() { @@ -680,6 +699,74 @@ mod tests { ); } + #[test] + fn concat_record_batches_of_different_metadata() { + let metadata = HashMap::from([("foo".to_string(), "bar".to_string())]); + let field = 
Field::new("a", DataType::Int32, false); + + let schema1 = Arc::new(Schema::new(vec![field.clone()])); + + let batch1 = + RecordBatch::try_new(schema1, vec![Arc::new(Int32Array::from(vec![1]))]) + .unwrap(); + + let schema2 = Arc::new(Schema::new(vec![field.with_metadata(metadata)])); + + let batch2 = + RecordBatch::try_new(schema2, vec![Arc::new(Int32Array::from(vec![3]))]) + .unwrap(); + + // should be able to concat batches with different metadata + let new_batch = concat_batches(&batch1.schema(), [&batch1, &batch2]).unwrap(); + assert_eq!(new_batch.schema(), batch1.schema()); + assert_eq!(2, new_batch.num_rows()); + + // using batch2 schema should also work + let new_batch = concat_batches(&batch2.schema(), [&batch1, &batch2]).unwrap(); + assert_eq!(new_batch.schema(), batch2.schema()); + assert_eq!(2, new_batch.num_rows()); + } + + #[test] + fn concat_record_batches_of_different_nullability() { + // is nullable + let field = Field::new("a", DataType::Int32, true); + let nullable_schema = Arc::new(Schema::new(vec![field.clone()])); + + let batch_with_nulls = RecordBatch::try_new( + nullable_schema, + vec![Arc::new(Int32Array::from(vec![Some(1), None]))], + ) + .unwrap(); + + let non_nullable_schema = Arc::new(Schema::new(vec![field.with_nullable(false)])); + + let batch_without_nulls = RecordBatch::try_new( + non_nullable_schema, + vec![Arc::new(Int32Array::from(vec![3]))], + ) + .unwrap(); + + // should be able to concat batches if the schema says it is + // nullable + let new_batch = concat_batches( + &batch_with_nulls.schema(), + [&batch_with_nulls, &batch_without_nulls], + ) + .unwrap(); + assert_eq!(new_batch.schema(), batch_with_nulls.schema()); + assert_eq!(3, new_batch.num_rows()); + + // should not be able to concat batches with nulls together if + // the schema says it is not nullable + let err = concat_batches( + &batch_without_nulls.schema(), + [&batch_with_nulls, &batch_without_nulls], + ) + .unwrap_err(); + assert_eq!(err.to_string(), "Invalid 
argument error: Column 'a' is declared as non-nullable but contains null values"); + } + #[test] fn concat_capacity() { let a = Int32Array::from_iter_values(0..100);