diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index b78c785ae279..0e66d71d4733 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -875,4 +875,43 @@ mod tests { UnionMode::Dense, ); } + + #[test] + fn test_struct_reorder() { + let mut datafields = DataType::Struct(Fields::from(vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Utf8, false), + ])); + let reverse_datafields = DataType::Struct(Fields::from(vec![ + Field::new("f2", DataType::Utf8, false), + Field::new("f1", DataType::Int32, false), + ])); + match datafields { + DataType::Struct(ref mut fields) => { + fields.reverse(); + } + _ => {} + }; + assert_eq!(datafields, reverse_datafields); + } + + #[test] + fn test_struct_push() { + let mut datafields = DataType::Struct(Fields::from(vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Utf8, false), + ])); + match datafields { + DataType::Struct(ref mut fields) => { + fields.push(Field::new("f3", DataType::Boolean, false)); + } + _ => {} + }; + let expected_datafields = DataType::Struct(Fields::from(vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Utf8, false), + Field::new("f3", DataType::Boolean, false), + ])); + assert_eq!(datafields, expected_datafields); + } } diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs index 07e9abeee56a..eae6b42e0c68 100644 --- a/arrow-schema/src/fields.rs +++ b/arrow-schema/src/fields.rs @@ -41,7 +41,7 @@ use std::sync::Arc; #[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", serde(transparent))] -pub struct Fields(Arc<[FieldRef]>); +pub struct Fields(Arc>); impl std::fmt::Debug for Fields { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -50,9 +50,13 @@ impl std::fmt::Debug for Fields { } impl Fields { + pub fn new(fields: Arc>) -> Self { + Self(fields) + } + /// Returns a new empty [`Fields`] pub fn empty() -> Self { - Self(Arc::new([])) + Self(Arc::new(vec![])) } /// Return size of this instance in bytes. @@ -83,6 +87,16 @@ impl Fields { .zip(other.iter()) .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b)) } + + pub fn reverse(&mut self) { + let new_fields: Vec = self.iter().rev().map(|f| f.clone() as FieldRef).collect(); + self.0 = Arc::new(new_fields); + } + + pub fn push(&mut self, field: Field) { + let fields = Arc::make_mut(&mut self.0); + fields.push(Arc::new(field)); + } } impl Default for Fields { @@ -99,7 +113,7 @@ impl FromIterator for Fields { impl FromIterator for Fields { fn from_iter>(iter: T) -> Self { - Self(iter.into_iter().collect()) + Self(Arc::new(iter.into_iter().map(|f| f as FieldRef).collect())) } } @@ -117,13 +131,13 @@ impl From> for Fields { impl From<&[FieldRef]> for Fields { fn from(value: &[FieldRef]) -> Self { - Self(value.into()) + Self(Arc::new(value.to_vec())) } } impl From<[FieldRef; N]> for Fields { fn from(value: [FieldRef; N]) -> Self { - Self(Arc::new(value)) + Self(Arc::new(value.to_vec())) } }