Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow concat_batches to work with RecordBatches that have different metadata #4800

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer};
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_schema::{ArrowError, DataType, SchemaRef};
use arrow_schema::{ArrowError, DataType, Schema, SchemaRef};
use std::sync::Arc;

fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
Expand Down Expand Up @@ -179,7 +179,7 @@ pub fn concat_batches<'a>(
if let Some((i, _)) = batches
.iter()
.enumerate()
.find(|&(_, batch)| batch.schema() != *schema)
.find(|&(_, batch)| !concatable_schema(schema.as_ref(), batch.schema().as_ref()))
{
return Err(ArrowError::InvalidArgumentError(format!(
"batches[{i}] schema is different with argument schema.
Expand All @@ -204,13 +204,32 @@ pub fn concat_batches<'a>(
RecordBatch::try_new(schema.clone(), arrays)
}

/// Returns `true` if data described by the `source` [`Schema`] can be placed
/// in a record batch that uses the `target` [`Schema`].
///
/// Two schemas are considered concatable when they have the same number of
/// fields and every pair of fields at the same position agrees on both name
/// and data type.
///
/// The following are deliberately *not* compared:
///
/// * metadata, so that batches differing only in metadata can be
///   concatenated (see <https://github.com/apache/arrow-rs/issues/4799>);
/// * nullability, because `RecordBatch::try_new()` validates that when the
///   concatenated batch is built.
fn concatable_schema(target: &Schema, source: &Schema) -> bool {
    // A differing column count can never line up, so bail out before
    // comparing individual fields (`zip` would silently truncate otherwise).
    if source.fields().len() != target.fields().len() {
        return false;
    }

    // Fields are matched positionally, not by name lookup.
    source
        .fields()
        .iter()
        .zip(target.fields().iter())
        .all(|(source_field, target_field)| {
            source_field.name() == target_field.name()
                && source_field.data_type() == target_field.data_type()
        })
}

#[cfg(test)]
mod tests {
use super::*;
use arrow_array::builder::StringDictionaryBuilder;
use arrow_array::cast::AsArray;
use arrow_schema::{Field, Schema};
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

#[test]
fn test_concat_empty_vec() {
Expand Down Expand Up @@ -759,6 +778,74 @@ mod tests {
);
}

#[test]
fn concat_record_batches_of_different_metadata() {
    // Two single-column schemas that are identical except that the second
    // one carries metadata on its field.
    let metadata = HashMap::from([("foo".to_string(), "bar".to_string())]);
    let plain_field = Field::new("a", DataType::Int32, false);
    let annotated_field = plain_field.clone().with_metadata(metadata);

    let batch1 = RecordBatch::try_new(
        Arc::new(Schema::new(vec![plain_field])),
        vec![Arc::new(Int32Array::from(vec![1]))],
    )
    .unwrap();

    let batch2 = RecordBatch::try_new(
        Arc::new(Schema::new(vec![annotated_field])),
        vec![Arc::new(Int32Array::from(vec![3]))],
    )
    .unwrap();

    // Concatenation must succeed whichever of the two schemas is chosen as
    // the target, and the result must carry exactly that target schema.
    for target in [batch1.schema(), batch2.schema()] {
        let merged = concat_batches(&target, [&batch1, &batch2]).unwrap();
        assert_eq!(merged.schema(), target);
        assert_eq!(2, merged.num_rows());
    }
}

#[test]
fn concat_record_batches_of_different_nullability() {
    // Same column name and type; only the nullability flag differs.
    let nullable_field = Field::new("a", DataType::Int32, true);
    let required_field = nullable_field.clone().with_nullable(false);

    // Two rows, one of which is null, under the nullable schema.
    let batch_with_nulls = RecordBatch::try_new(
        Arc::new(Schema::new(vec![nullable_field])),
        vec![Arc::new(Int32Array::from(vec![Some(1), None]))],
    )
    .unwrap();

    // One non-null row under the non-nullable schema.
    let batch_without_nulls = RecordBatch::try_new(
        Arc::new(Schema::new(vec![required_field])),
        vec![Arc::new(Int32Array::from(vec![3]))],
    )
    .unwrap();

    let inputs = [&batch_with_nulls, &batch_without_nulls];

    // Targeting the nullable schema is fine: it can hold both inputs.
    let nullable_target = batch_with_nulls.schema();
    let merged = concat_batches(&nullable_target, inputs).unwrap();
    assert_eq!(merged.schema(), batch_with_nulls.schema());
    assert_eq!(3, merged.num_rows());

    // Targeting the non-nullable schema must fail, because the combined
    // column contains a null value.
    let required_target = batch_without_nulls.schema();
    let err = concat_batches(&required_target, inputs).unwrap_err();
    assert_eq!(
        err.to_string(),
        "Invalid argument error: Column 'a' is declared as non-nullable but contains null values"
    );
}

#[test]
fn concat_capacity() {
let a = Int32Array::from_iter_values(0..100);
Expand Down