From c9029c53883d984bdb7421d0d37ad7c65c8fe95a Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Mon, 12 Feb 2024 06:57:11 -0500 Subject: [PATCH] Don't omit schema metadata when removing column (#5328) * Don't omit schema metadata when removing column * Add test * Update arrow-schema/src/schema.rs Co-authored-by: Andrew Lamb --------- Co-authored-by: Andrew Lamb --- arrow-array/src/record_batch.rs | 31 ++++++++++++++++++++++++++++++- arrow-schema/src/schema.rs | 6 ++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index dab6ae343a77..d89020a65681 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -355,7 +355,7 @@ impl RecordBatch { /// assert_eq!(batch.num_columns(), 1); /// ``` pub fn remove_column(&mut self, index: usize) -> ArrayRef { - let mut builder = SchemaBuilder::from(self.schema.fields()); + let mut builder = SchemaBuilder::from(self.schema.as_ref()); builder.remove(index); self.schema = Arc::new(builder.finish()); self.columns.remove(index) @@ -618,6 +618,8 @@ where #[cfg(test)] mod tests { + use std::collections::HashMap; + use super::*; use crate::{BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray}; use arrow_buffer::{Buffer, ToByteSlice}; @@ -1155,4 +1157,31 @@ mod tests { let size = get_size(reader); assert_eq!(size, 0); } + + #[test] + fn test_remove_column_maintains_schema_metadata() { + let id_array = Int32Array::from(vec![1, 2, 3, 4, 5]); + let bool_array = BooleanArray::from(vec![true, false, false, true, true]); + + let mut metadata = HashMap::new(); + metadata.insert("foo".to_string(), "bar".to_string()); + let schema = Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("bool", DataType::Boolean, false), + ]) + .with_metadata(metadata); + + let mut batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(id_array), Arc::new(bool_array)], + ) + .unwrap(); + + let 
_removed_column = batch.remove_column(0); + assert_eq!(batch.schema().metadata().len(), 1); + assert_eq!( + batch.schema().metadata().get("foo").unwrap().as_str(), + "bar" + ); + } } diff --git a/arrow-schema/src/schema.rs b/arrow-schema/src/schema.rs index e547e5df3a5a..ede158fcf248 100644 --- a/arrow-schema/src/schema.rs +++ b/arrow-schema/src/schema.rs @@ -140,6 +140,12 @@ impl From for SchemaBuilder { } } +impl From<&Schema> for SchemaBuilder { + fn from(value: &Schema) -> Self { + Self::from(value.clone()) + } +} + impl From<Schema> for SchemaBuilder { fn from(value: Schema) -> Self { Self {