Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support nested schema projection (#5148) #5149

Merged
merged 3 commits into from
Nov 29, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 231 additions & 1 deletion arrow-schema/src/fields.rs
Original file line number Diff line number Diff line change
@@ -15,10 +15,11 @@
// specific language governing permissions and limitations
// under the License.

use crate::{ArrowError, Field, FieldRef, SchemaBuilder};
use std::ops::Deref;
use std::sync::Arc;

use crate::{ArrowError, DataType, Field, FieldRef, SchemaBuilder};

/// A cheaply cloneable, owned slice of [`FieldRef`]
///
/// Similar to `Arc<Vec<FieldRef>>` or `Arc<[FieldRef]>`
@@ -99,6 +100,108 @@ impl Fields {
.all(|(a, b)| Arc::ptr_eq(a, b) || a.contains(b))
}

/// Returns a copy of this [`Fields`] containing only those [`FieldRef`] passing a predicate
///
/// Performs a depth-first scan of [`Fields`] invoking `filter` for each [`FieldRef`]
/// containing no child [`FieldRef`], a leaf field, along with a count of the number
/// of such leaves encountered so far. Only [`FieldRef`] for which `filter`
/// returned `true` will be included in the result.
///
/// This can therefore be used to select a subset of fields from nested types
/// such as [`DataType::Struct`] or [`DataType::List`].
///
/// ```
/// # use arrow_schema::{DataType, Field, Fields};
/// let fields = Fields::from(vec![
/// Field::new("a", DataType::Int32, true), // Leaf 0
/// Field::new("b", DataType::Struct(Fields::from(vec![
/// Field::new("c", DataType::Float32, false), // Leaf 1
/// Field::new("d", DataType::Float64, false), // Leaf 2
/// Field::new("e", DataType::Struct(Fields::from(vec![
/// Field::new("f", DataType::Int32, false), // Leaf 3
/// Field::new("g", DataType::Float16, false), // Leaf 4
/// ])), true),
/// ])), false)
/// ]);
/// let filtered = fields.filter_leaves(|idx, _| [0, 2, 3, 4].contains(&idx));
/// let expected = Fields::from(vec![
/// Field::new("a", DataType::Int32, true),
/// Field::new("b", DataType::Struct(Fields::from(vec![
/// Field::new("d", DataType::Float64, false),
/// Field::new("e", DataType::Struct(Fields::from(vec![
/// Field::new("f", DataType::Int32, false),
/// Field::new("g", DataType::Float16, false),
/// ])), true),
/// ])), false)
/// ]);
/// assert_eq!(filtered, expected);
/// ```
pub fn filter_leaves<F: FnMut(usize, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
Copy link
Contributor Author

@tustvold tustvold Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Originally I had two methods, one which was called for all Fields, however, it wasn't clear to me how such an API could be used - as it would lack any context as to where in the schema a given field appeared. I therefore just stuck with filter_leaves which has immediate concrete utility for #5135. A case could be made for a fully-fledged visitor pattern, but I'm not aware of any immediate use-cases for this, and therefore what the requirements might be for such an API.

FYI @Jefffrey

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder how we would support selecting a subfield that is itself a field?

For example:

{ 
  a : {
    b: {
      c: 1, 
    },
    d: 2
}

Could I use this API to support only selecting a.b?

{ 
  a : {
    b: {
      c: 1, 
    },
}

Maybe we could if the parent FieldRef was also passed

    /// calls filter(depth, parent_field, child_field)` 
    pub fn filter_leaves<F: FnMut(usize, &FieldRef, &FieldRef) -> bool>(&self, mut filter: F) -> Self {
...
}

🤔

Copy link
Contributor Author

@tustvold tustvold Nov 29, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You would select b by selecting all its children, if you call filter with intermediate fields you end up in a position of how much context do you provide, just the parent, what about the parent's parent - reasoning in terms of leaves is the only coherent mechanism I can think of

fn filter_field<F: FnMut(&FieldRef) -> bool>(
f: &FieldRef,
filter: &mut F,
) -> Option<FieldRef> {
use DataType::*;

let v = match f.data_type() {
Dictionary(_, v) => v.as_ref(), // Key must be integer
RunEndEncoded(_, v) => v.data_type(), // Run-ends must be integer
d => d,
};
let d = match v {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Map and RunEndEncoded appear to have embedded fields too. It seems like they might also need to be handled 🤔

The usecase might be "I want only the values of the map? Or maybe Arrow can't express that concept (Apply the filter to only the values 🤔 )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Map is there, I missed RunEndEncoded 👍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RunEndEncoded support added in 8d97950

List(child) => List(filter_field(child, filter)?),
LargeList(child) => LargeList(filter_field(child, filter)?),
Map(child, ordered) => Map(filter_field(child, filter)?, *ordered),
FixedSizeList(child, size) => FixedSizeList(filter_field(child, filter)?, *size),
Struct(fields) => {
let filtered: Fields = fields
.iter()
.filter_map(|f| filter_field(f, filter))
.collect();

if filtered.is_empty() {
return None;
}

Struct(filtered)
}
Union(fields, mode) => {
let filtered: UnionFields = fields
.iter()
.filter_map(|(id, f)| Some((id, filter_field(f, filter)?)))
.collect();

if filtered.is_empty() {
return None;
}

Union(filtered, *mode)
}
_ => return filter(f).then(|| f.clone()),
};
let d = match f.data_type() {
Dictionary(k, _) => Dictionary(k.clone(), Box::new(d)),
RunEndEncoded(v, f) => {
RunEndEncoded(v.clone(), Arc::new(f.as_ref().clone().with_data_type(d)))
}
_ => d,
};
Some(Arc::new(f.as_ref().clone().with_data_type(d)))
}

let mut leaf_idx = 0;
let mut filter = |f: &FieldRef| {
let t = filter(leaf_idx, f);
leaf_idx += 1;
t
};

self.0
.iter()
.filter_map(|f| filter_field(f, &mut filter))
.collect()
}

/// Remove a field by index and return it.
///
/// # Panic
@@ -307,3 +410,130 @@ impl FromIterator<(i8, FieldRef)> for UnionFields {
Self(iter.into_iter().collect())
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::UnionMode;

#[test]
fn test_filter() {
let floats = Fields::from(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
]);
let fields = Fields::from(vec![
Field::new("a", DataType::Int32, true),
Field::new("floats", DataType::Struct(floats.clone()), true),
Field::new("b", DataType::Int16, true),
Field::new(
"c",
DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)),
false,
),
Field::new(
"d",
DataType::Dictionary(
Box::new(DataType::Int32),
Box::new(DataType::Struct(floats.clone())),
),
false,
),
Field::new_list(
"e",
Field::new("floats", DataType::Struct(floats.clone()), true),
true,
),
Field::new(
"f",
DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Int32, false)), 3),
false,
),
Field::new_map(
"g",
"entries",
Field::new("keys", DataType::LargeUtf8, false),
Field::new("values", DataType::Int32, true),
false,
false,
),
Field::new(
"h",
DataType::Union(
UnionFields::new(
vec![1, 3],
vec![
Field::new("field1", DataType::UInt8, false),
Field::new("field3", DataType::Utf8, false),
],
),
UnionMode::Dense,
),
true,
),
Field::new(
"i",
DataType::RunEndEncoded(
Arc::new(Field::new("run_ends", DataType::Int32, false)),
Arc::new(Field::new("values", DataType::Struct(floats.clone()), true)),
),
false,
),
]);

let floats_a = DataType::Struct(vec![floats[0].clone()].into());

let r = fields.filter_leaves(|idx, _| idx == 0 || idx == 1);
assert_eq!(r.len(), 2);
assert_eq!(r[0], fields[0]);
assert_eq!(r[1].data_type(), &floats_a);

let r = fields.filter_leaves(|_, f| f.name() == "a");
assert_eq!(r.len(), 5);
assert_eq!(r[0], fields[0]);
assert_eq!(r[1].data_type(), &floats_a);
assert_eq!(
r[2].data_type(),
&DataType::Dictionary(Box::new(DataType::Int32), Box::new(floats_a.clone()))
);
assert_eq!(
r[3].as_ref(),
&Field::new_list("e", Field::new("floats", floats_a.clone(), true), true)
);
assert_eq!(
r[4].as_ref(),
&Field::new(
"i",
DataType::RunEndEncoded(
Arc::new(Field::new("run_ends", DataType::Int32, false)),
Arc::new(Field::new("values", floats_a.clone(), true)),
),
false,
)
);

let r = fields.filter_leaves(|_, f| f.name() == "floats");
assert_eq!(r.len(), 0);

let r = fields.filter_leaves(|idx, _| idx == 9);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[6]);

let r = fields.filter_leaves(|idx, _| idx == 10 || idx == 11);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[7]);

let union = DataType::Union(
UnionFields::new(vec![1], vec![Field::new("field1", DataType::UInt8, false)]),
UnionMode::Dense,
);

let r = fields.filter_leaves(|idx, _| idx == 12);
assert_eq!(r.len(), 1);
assert_eq!(r[0].data_type(), &union);

let r = fields.filter_leaves(|idx, _| idx == 14 || idx == 15);
assert_eq!(r.len(), 1);
assert_eq!(r[0], fields[9]);
}
}