diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index c34c3d3d0ccf..a6dcca24eace 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -159,7 +159,12 @@ fn concat_fallback( Ok(make_array(mutable.freeze())) } -/// Concatenates `batches` together into a single record batch. +/// Concatenates `batches` together into a single [`RecordBatch`]. +/// +/// The output batch has the specified `schema`; the schemas of the +/// input batches are ignored. +/// +/// Returns an error if the types of underlying arrays are different. pub fn concat_batches<'a>( schema: &SchemaRef, input_batches: impl IntoIterator, @@ -176,20 +181,6 @@ pub fn concat_batches<'a>( if batches.is_empty() { return Ok(RecordBatch::new_empty(schema.clone())); } - if let Some((i, _)) = batches - .iter() - .enumerate() - .find(|&(_, batch)| batch.schema() != *schema) - { - return Err(ArrowError::InvalidArgumentError(format!( - "batches[{i}] schema is different with argument schema. - batches[{i}] schema: {:?}, - argument schema: {:?} - ", - batches[i].schema(), - *schema - ))); - } let field_num = schema.fields().len(); let mut arrays = Vec::with_capacity(field_num); for i in 0..field_num { @@ -727,36 +718,45 @@ mod tests { } #[test] - fn concat_record_batches_of_different_schemas() { - let schema1 = Arc::new(Schema::new(vec![ - Field::new("a", DataType::Int32, false), - Field::new("b", DataType::Utf8, false), - ])); - let schema2 = Arc::new(Schema::new(vec![ - Field::new("c", DataType::Int32, false), - Field::new("d", DataType::Utf8, false), - ])); + fn concat_record_batches_of_different_schemas_but_compatible_data() { + let schema1 = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + // column names differ + let schema2 = + Arc::new(Schema::new(vec![Field::new("c", DataType::Int32, false)])); let batch1 = RecordBatch::try_new( schema1.clone(), - vec![ - Arc::new(Int32Array::from(vec![1, 2])), - Arc::new(StringArray::from(vec!["a", "b"])), - ], + 
vec![Arc::new(Int32Array::from(vec![1, 2]))], + ) + .unwrap(); + let batch2 = + RecordBatch::try_new(schema2, vec![Arc::new(Int32Array::from(vec![3, 4]))]) + .unwrap(); + // concat_batches simply uses the schema provided + let batch = concat_batches(&schema1, [&batch1, &batch2]).unwrap(); + assert_eq!(batch.schema().as_ref(), schema1.as_ref()); + assert_eq!(4, batch.num_rows()); + } + + #[test] + fn concat_record_batches_of_different_schemas_incompatible_data() { + let schema1 = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)])); + // column data types differ + let schema2 = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)])); + let batch1 = RecordBatch::try_new( + schema1.clone(), + vec![Arc::new(Int32Array::from(vec![1, 2]))], ) .unwrap(); let batch2 = RecordBatch::try_new( schema2, - vec![ - Arc::new(Int32Array::from(vec![3, 4])), - Arc::new(StringArray::from(vec!["c", "d"])), - ], + vec![Arc::new(StringArray::from(vec!["foo", "bar"]))], ) .unwrap(); + let error = concat_batches(&schema1, [&batch1, &batch2]).unwrap_err(); - assert_eq!( - error.to_string(), - "Invalid argument error: batches[1] schema is different with argument schema.\n batches[1] schema: Schema { fields: [Field { name: \"c\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"d\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} },\n argument schema: Schema { fields: [Field { name: \"a\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"b\", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }\n " - ); + assert_eq!(error.to_string(), "Invalid argument error: It is not possible to concatenate arrays of different data types."); } #[test]