From b84408b663bde4b1d3a4c80f1dbcae3d4bad813c Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 29 Nov 2023 13:33:50 +0000 Subject: [PATCH 1/3] Support nested schema projection --- arrow-schema/src/fields.rs | 191 ++++++++++++++++++++++++++++++++++++- 1 file changed, 190 insertions(+), 1 deletion(-) diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs index f90632455fd9..b2ecbb796312 100644 --- a/arrow-schema/src/fields.rs +++ b/arrow-schema/src/fields.rs @@ -15,10 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::{ArrowError, Field, FieldRef, SchemaBuilder}; use std::ops::Deref; use std::sync::Arc; +use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder}; + /// A cheaply cloneable, owned slice of [`FieldRef`] /// /// Similar to `Arc>` or `Arc<[FieldRef]>` @@ -99,6 +100,90 @@ impl Fields { .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b)) } + /// Performs a depth-first scan of [`Fields`] filtering the [`FieldRef`] with no children + /// + /// Returns a new [`Fields`] comprising the [`FieldRef`] for which `filter` returned `true` + /// + /// ``` + /// # use arrow_schema::{DataType, Field, Fields}; + /// let fields = Fields::from(vec![ + /// Field::new("a", DataType::Int32, true), + /// Field::new("b", DataType::Struct(Fields::from(vec![ + /// Field::new("c", DataType::Float32, false), + /// Field::new("d", DataType::Float64, false), + /// ])), false) + /// ]); + /// let filtered = fields.filter_leaves(|idx, _| idx == 0 || idx == 2); + /// let expected = Fields::from(vec![ + /// Field::new("a", DataType::Int32, true), + /// Field::new("b", DataType::Struct(Fields::from(vec![ + /// Field::new("d", DataType::Float64, false), + /// ])), false) + /// ]); + /// assert_eq!(filtered, expected); + /// ``` + pub fn filter_leaves bool>(&self, mut filter: F) -> Self { + fn filter_field bool>( + f: &FieldRef, + filter: &mut F, + ) -> Option { + use DataType::*; + + let (k, v) = match f.data_type() { + Dictionary(k, v) => (Some(k.clone()), v.as_ref()), + d => (None, d), + }; + let d = match v { + List(child) => List(filter_field(child, filter)?), + LargeList(child) => LargeList(filter_field(child, filter)?), + Map(child, ordered) => Map(filter_field(child, filter)?, *ordered), + FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size), + Struct(fields) => { + let filtered: Fields = fields + .iter() + .filter_map(|f| filter_field(f, filter)) + .collect(); + + if filtered.is_empty() { + return None; + } + + Struct(filtered) + } + Union(fields, mode) => { + let filtered: UnionFields = fields + .iter() + .filter_map(|(id, f)| Some((id, filter_field(f, filter)?))) + .collect(); + + if filtered.is_empty() { + return None; + } + + Union(filtered, *mode) + } + _ => return filter(f).then(|| f.clone()), + }; + let d = match k { + Some(k) => Dictionary(k, Box::new(d)), + None => d, + }; + Some(Arc::new(f.as_ref().clone().with_data_type(d))) + } + + let mut leaf_idx = 0; + let mut filter = |f: &FieldRef| { + let t = filter(leaf_idx, f); + leaf_idx += 1; + t + }; + + self.0 + .iter() + .filter_map(|f| filter_field(f, &mut filter)) + .collect() + } + /// Remove a field by index and return it. /// /// # Panic @@ -307,3 +392,107 @@ impl FromIterator<(i8, FieldRef)> for UnionFields { Self(iter.into_iter().collect()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::UnionMode; + + #[test] + fn test_filter() { + let floats = Fields::from(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ]); + let fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("floats", DataType::Struct(floats.clone()), true), + Field::new("b", DataType::Int16, true), + Field::new( + "c", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ), + Field::new( + "d", + DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Struct(floats.clone())), + ), + false, + ), + Field::new_list( + "e", + Field::new("floats", DataType::Struct(floats.clone()), true), + true, + ), + Field::new( + "f", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3), + false, + ), + Field::new_map( + "g", + "entries", + Field::new("keys", DataType::LargeUtf8, false), + Field::new("values", DataType::Int32, true), + false, + false, + ), + Field::new( + "h", + DataType::Union( + UnionFields::new( + vec![1, 3], + vec![ + Field::new("field1", DataType::UInt8, false), + Field::new("field3", DataType::Utf8, false), + ], + ), + UnionMode::Dense, + ), + true, + ), + ]); + + let floats_a = DataType::Struct(vec![floats[0].clone()].into()); + + let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1); + assert_eq!(r.len(), 2); + assert_eq!(r[0], fields[0]); + assert_eq!(r[1].data_type(), &floats_a); + + let r = fields.filter_leaves(|_, f| f.name() == "a"); + assert_eq!(r.len(), 4); + assert_eq!(r[0], fields[0]); + assert_eq!(r[1].data_type(), &floats_a); + assert_eq!( + r[2].data_type(), + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone())) + ); + assert_eq!( + r[3].as_ref(), + &Field::new_list("e", Field::new("floats", floats_a.clone(), true), true) + ); + + let r = fields.filter_leaves(|_, f| f.name() == "floats"); + assert_eq!(r.len(), 0); + + let r = fields.filter_leaves(|idx, _| idx == 9); + assert_eq!(r.len(), 1); + assert_eq!(r[0], fields[6]); + + let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11); + assert_eq!(r.len(), 1); + assert_eq!(r[0], fields[7]); + + let union = DataType::Union( + UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]), + UnionMode::Dense, + ); + + let r = fields.filter_leaves(|idx, _| idx == 12); + assert_eq!(r.len(), 1); + assert_eq!(r[0].data_type(), &union); + } +} From 0aab1d17518f6c66ce3eb7d33abadd52a5f78f01 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 29 Nov 2023 13:49:34 +0000 Subject: [PATCH 2/3] Tweak doc --- arrow-schema/src/fields.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs index b2ecbb796312..70c18ba08eea 100644 --- a/arrow-schema/src/fields.rs +++ b/arrow-schema/src/fields.rs @@ -102,6 +102,9 @@ impl Fields { /// Performs a depth-first scan of [`Fields`] filtering the [`FieldRef`] with no children /// + /// Invokes `filter` with each leaf [`FieldRef`], i.e. one containing no children, and a + /// count of the number of previous calls to `filter` - i.e. the leaf's index. + /// /// Returns a new [`Fields`] comprising the [`FieldRef`] for which `filter` returned `true` /// /// ``` From 8d9795069c8244820b2107c76f46d9d484ee4740 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 29 Nov 2023 20:17:40 +0000 Subject: [PATCH 3/3] Review feedback --- arrow-schema/src/fields.rs | 68 +++++++++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/arrow-schema/src/fields.rs b/arrow-schema/src/fields.rs index 70c18ba08eea..400f42c59c30 100644 --- a/arrow-schema/src/fields.rs +++ b/arrow-schema/src/fields.rs @@ -100,27 +100,38 @@ impl Fields { .all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b)) } - /// Performs a depth-first scan of [`Fields`] filtering the [`FieldRef`] with no children + /// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate /// - /// Invokes `filter` with each leaf [`FieldRef`], i.e. one containing no children, and a - /// count of the number of previous calls to `filter` - i.e. the leaf's index. + /// Performs a depth-first scan of [`Fields`] invoking `filter` for each [`FieldRef`] + /// containing no child [`FieldRef`], a leaf field, along with a count of the number + /// of such leaves encountered so far. Only [`FieldRef`] for which `filter` + /// returned `true` will be included in the result. /// - /// Returns a new [`Fields`] comprising the [`FieldRef`] for which `filter` returned `true` + /// This can therefore be used to select a subset of fields from nested types + /// such as [`DataType::Struct`] or [`DataType::List`]. /// /// ``` /// # use arrow_schema::{DataType, Field, Fields}; /// let fields = Fields::from(vec![ - /// Field::new("a", DataType::Int32, true), + /// Field::new("a", DataType::Int32, true), // Leaf 0 /// Field::new("b", DataType::Struct(Fields::from(vec![ - /// Field::new("c", DataType::Float32, false), - /// Field::new("d", DataType::Float64, false), + /// Field::new("c", DataType::Float32, false), // Leaf 1 + /// Field::new("d", DataType::Float64, false), // Leaf 2 + /// Field::new("e", DataType::Struct(Fields::from(vec![ + /// Field::new("f", DataType::Int32, false), // Leaf 3 + /// Field::new("g", DataType::Float16, false), // Leaf 4 + /// ])), true), /// ])), false) /// ]); - /// let filtered = fields.filter_leaves(|idx, _| idx == 0 || idx == 2); + /// let filtered = fields.filter_leaves(|idx, _| [0, 2, 3, 4].contains(&idx)); /// let expected = Fields::from(vec![ /// Field::new("a", DataType::Int32, true), /// Field::new("b", DataType::Struct(Fields::from(vec![ /// Field::new("d", DataType::Float64, false), + /// Field::new("e", DataType::Struct(Fields::from(vec![ + /// Field::new("f", DataType::Int32, false), + /// Field::new("g", DataType::Float16, false), + /// ])), true), /// ])), false) /// ]); /// assert_eq!(filtered, expected); @@ -132,9 +143,10 @@ impl Fields { ) -> Option { use DataType::*; - let (k, v) = match f.data_type() { - Dictionary(k, v) => (Some(k.clone()), v.as_ref()), - d => (None, d), + let v = match f.data_type() { + Dictionary(_, v) => v.as_ref(), // Key must be integer + RunEndEncoded(_, v) => v.data_type(), // Run-ends must be integer + d => d, }; let d = match v { List(child) => List(filter_field(child, filter)?), @@ -167,9 +179,12 @@ impl Fields { } _ => return filter(f).then(|| f.clone()), }; - let d = match k { - Some(k) => Dictionary(k, Box::new(d)), - None => d, + let d = match f.data_type() { + Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)), + RunEndEncoded(v, f) => { + RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d))) + } + _ => d, }; Some(Arc::new(f.as_ref().clone().with_data_type(d))) } @@ -456,6 +471,14 @@ mod tests { ), true, ), + Field::new( + "i", + DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)), + ), + false, + ), ]); let floats_a = DataType::Struct(vec![floats[0].clone()].into()); @@ -466,7 +489,7 @@ mod tests { assert_eq!(r[1].data_type(), &floats_a); let r = fields.filter_leaves(|_, f| f.name() == "a"); - assert_eq!(r.len(), 4); + assert_eq!(r.len(), 5); assert_eq!(r[0], fields[0]); assert_eq!(r[1].data_type(), &floats_a); assert_eq!( @@ -477,6 +500,17 @@ mod tests { r[3].as_ref(), &Field::new_list("e", Field::new("floats", floats_a.clone(), true), true) ); + assert_eq!( + r[4].as_ref(), + &Field::new( + "i", + DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", floats_a.clone(), true)), + ), + false, + ) + ); let r = fields.filter_leaves(|_, f| f.name() == "floats"); assert_eq!(r.len(), 0); @@ -497,5 +531,9 @@ mod tests { let r = fields.filter_leaves(|idx, _| idx == 12); assert_eq!(r.len(), 1); assert_eq!(r[0].data_type(), &union); + + let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15); + assert_eq!(r.len(), 1); + assert_eq!(r[0], fields[9]); } }