Skip to content

Commit

Permalink
Support nested schema projection
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Nov 29, 2023
1 parent cfdb505 commit b84408b
Showing 1 changed file with 190 additions and 1 deletion.
191 changes: 190 additions & 1 deletion arrow-schema/src/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
// specific language governing permissions and limitations
// under the License.

use crate::{ArrowError, Field, FieldRef, SchemaBuilder};
use std::ops::Deref;
use std::sync::Arc;

use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder};

/// A cheaply cloneable, owned slice of [`FieldRef`]
///
/// Similar to `Arc<Vec<FieldRef>>` or `Arc<[FieldRef]>`
Expand Down Expand Up @@ -99,6 +100,90 @@ impl Fields {
.all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
}

/// Performs a depth-first scan of [`Fields`] filtering the [`FieldRef`] with no children
///
/// Returns a new [`Fields`] comprising the [`FieldRef`] for which `filter` returned `true`
///
/// ```
/// # use arrow_schema::{DataType, Field, Fields};
/// let fields = Fields::from(vec![
/// Field::new("a", DataType::Int32, true),
/// Field::new("b", DataType::Struct(Fields::from(vec![
/// Field::new("c", DataType::Float32, false),
/// Field::new("d", DataType::Float64, false),
/// ])), false)
/// ]);
/// let filtered = fields.filter_leaves(|idx, _| idx == 0 || idx == 2);
/// let expected = Fields::from(vec![
/// Field::new("a", DataType::Int32, true),
/// Field::new("b", DataType::Struct(Fields::from(vec![
/// Field::new("d", DataType::Float64, false),
/// ])), false)
/// ]);
/// assert_eq!(filtered, expected);
/// ```
pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
fn filter_field<F: FnMut(&FieldRef) -> bool>(
f: &FieldRef,
filter: &mut F,
) -> Option<FieldRef> {
use DataType::*;

let (k, v) = match f.data_type() {
Dictionary(k, v) => (Some(k.clone()), v.as_ref()),
d => (None, d),
};
let d = match v {
List(child) => List(filter_field(child, filter)?),
LargeList(child) => LargeList(filter_field(child, filter)?),
Map(child, ordered) => Map(filter_field(child, filter)?, *ordered),
FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size),
Struct(fields) => {
let filtered: Fields = fields
.iter()
.filter_map(|f| filter_field(f, filter))
.collect();

if filtered.is_empty() {
return None;
}

Struct(filtered)
}
Union(fields, mode) => {
let filtered: UnionFields = fields
.iter()
.filter_map(|(id, f)| Some((id, filter_field(f, filter)?)))
.collect();

if filtered.is_empty() {
return None;
}

Union(filtered, *mode)
}
_ => return filter(f).then(|| f.clone()),
};
let d = match k {
Some(k) => Dictionary(k, Box::new(d)),
None => d,
};
Some(Arc::new(f.as_ref().clone().with_data_type(d)))
}

let mut leaf_idx = 0;
let mut filter = |f: &FieldRef| {
let t = filter(leaf_idx, f);
leaf_idx += 1;
t
};

self.0
.iter()
.filter_map(|f| filter_field(f, &mut filter))
.collect()
}

/// Remove a field by index and return it.
///
/// # Panic
Expand Down Expand Up @@ -307,3 +392,107 @@ impl FromIterator<(i8, FieldRef)> for UnionFields {
Self(iter.into_iter().collect())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::UnionMode;

#[test]
fn test_filter() {
let floats = Fields::from(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
]);
let fields = Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("floats", DataType::Struct(floats.clone()), true),
Field::new("b", DataType::Int16, true),
Field::new(
"c",
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
false,
),
Field::new(
"d",
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Struct(floats.clone())),
),
false,
),
Field::new_list(
"e",
Field::new("floats", DataType::Struct(floats.clone()), true),
true,
),
Field::new(
"f",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3),
false,
),
Field::new_map(
"g",
"entries",
Field::new("keys", DataType::LargeUtf8, false),
Field::new("values", DataType::Int32, true),
false,
false,
),
Field::new(
"h",
DataType::Union(
UnionFields::new(
vec![1, 3],
vec![
Field::new("field1", DataType::UInt8, false),
Field::new("field3", DataType::Utf8, false),
],
),
UnionMode::Dense,
),
true,
),
]);

let floats_a = DataType::Struct(vec![floats[0].clone()].into());

let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1);
assert_eq!(r.len(), 2);
assert_eq!(r[0], fields[0]);
assert_eq!(r[1].data_type(), &floats_a);

let r = fields.filter_leaves(|_, f| f.name() == "a");
assert_eq!(r.len(), 4);
assert_eq!(r[0], fields[0]);
assert_eq!(r[1].data_type(), &floats_a);
assert_eq!(
r[2].data_type(),
&DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone()))
);
assert_eq!(
r[3].as_ref(),
&Field::new_list("e", Field::new("floats", floats_a.clone(), true), true)
);

let r = fields.filter_leaves(|_, f| f.name() == "floats");
assert_eq!(r.len(), 0);

let r = fields.filter_leaves(|idx, _| idx == 9);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[6]);

let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[7]);

let union = DataType::Union(
UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]),
UnionMode::Dense,
);

let r = fields.filter_leaves(|idx, _| idx == 12);
assert_eq!(r.len(), 1);
assert_eq!(r[0].data_type(), &union);
}
}

0 comments on commit b84408b

Please sign in to comment.