diff --git a/vortex-expr/src/lib.rs b/vortex-expr/src/lib.rs index c4c5bf5d47..e1a79959b1 100644 --- a/vortex-expr/src/lib.rs +++ b/vortex-expr/src/lib.rs @@ -10,6 +10,7 @@ mod like; mod literal; mod not; mod operators; +mod pack; mod project; pub mod pruning; mod row_filter; @@ -24,6 +25,7 @@ pub use like::*; pub use literal::*; pub use not::*; pub use operators::*; +pub use pack::*; pub use project::*; pub use row_filter::*; pub use select::*; diff --git a/vortex-expr/src/pack.rs b/vortex-expr/src/pack.rs new file mode 100644 index 0000000000..805b1dc433 --- /dev/null +++ b/vortex-expr/src/pack.rs @@ -0,0 +1,252 @@ +use std::any::Any; +use std::fmt::Display; +use std::sync::Arc; + +use itertools::Itertools as _; +use vortex_array::array::StructArray; +use vortex_array::validity::Validity; +use vortex_array::{ArrayData, IntoArrayData}; +use vortex_dtype::FieldNames; +use vortex_error::{vortex_bail, VortexExpect as _, VortexResult}; + +use crate::{ExprRef, VortexExpr}; + +/// Pack zero or more expressions into a structure with named fields. +/// +/// # Examples +/// +/// ``` +/// use vortex_array::IntoArrayData; +/// use vortex_array::compute::scalar_at; +/// use vortex_buffer::buffer; +/// use vortex_expr::{Pack, Identity, VortexExpr}; +/// use vortex_scalar::Scalar; +/// +/// let example = Pack::try_new_expr( +/// ["x".into(), "x copy".into(), "second x copy".into()].into(), +/// vec![Identity::new_expr(), Identity::new_expr(), Identity::new_expr()], +/// ).unwrap(); +/// let packed = example.evaluate(&buffer![100, 110, 200].into_array()).unwrap(); +/// let x_copy = packed +/// .as_struct_array() +/// .unwrap() +/// .field_by_name("x copy") +/// .unwrap(); +/// assert_eq!(scalar_at(&x_copy, 0).unwrap(), Scalar::from(100)); +/// assert_eq!(scalar_at(&x_copy, 1).unwrap(), Scalar::from(110)); +/// assert_eq!(scalar_at(&x_copy, 2).unwrap(), Scalar::from(200)); +/// ``` +/// +#[derive(Debug, Clone)] +pub struct Pack { + names: FieldNames, + values: Vec, +} + +impl Pack { + pub fn try_new_expr(names: FieldNames, values: Vec) -> VortexResult> { + if names.len() != values.len() { + vortex_bail!("length mismatch {} {}", names.len(), values.len()); + } + Ok(Arc::new(Pack { names, values })) + } +} + +impl PartialEq for Pack { + fn eq(&self, other: &dyn Any) -> bool { + other.downcast_ref::().is_some_and(|other_pack| { + self.names == other_pack.names + && self + .values + .iter() + .zip(other_pack.values.iter()) + .all(|(x, y)| x.eq(y)) + }) + } +} + +impl Display for Pack { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut f = f.debug_struct("Pack"); + for (name, value) in self.names.iter().zip_eq(self.values.iter()) { + f.field(name, value); + } + f.finish() + } +} + +impl VortexExpr for Pack { + fn as_any(&self) -> &dyn Any { + self + } + + fn evaluate(&self, batch: &ArrayData) -> VortexResult { + let len = batch.len(); + let value_arrays = self + .values + .iter() + .map(|value_expr| value_expr.evaluate(batch)) + .process_results(|it| it.collect::>())?; + StructArray::try_new(self.names.clone(), value_arrays, len, Validity::NonNullable) + .map(IntoArrayData::into_array) + } + + fn children(&self) -> Vec<&ExprRef> { + self.values.iter().collect() + } + + fn replacing_children(self: Arc, children: Vec) -> ExprRef { + assert_eq!(children.len(), self.values.len()); + Self::try_new_expr(self.names.clone(), children) + .vortex_expect("children are known to have the same length as names") + } +} + +impl PartialEq for Pack { + fn eq(&self, other: &Pack) -> bool { + self.names == other.names && self.values == other.values + } +} + +impl Eq for Pack {} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use vortex_array::array::{PrimitiveArray, StructArray}; + use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant as _}; + use vortex_buffer::buffer; + use vortex_dtype::{Field, FieldNames}; + use vortex_error::{vortex_bail, vortex_err, VortexResult}; + + use crate::{col, Column, Pack, VortexExpr}; + + fn test_array() -> StructArray { + StructArray::from_fields(&[ + ("a", buffer![0, 1, 2].into_array()), + ("b", buffer![4, 5, 6].into_array()), + ]) + .unwrap() + } + + fn primitive_field(array: &ArrayData, field_path: &[&str]) -> VortexResult { + let mut field_path = field_path.iter(); + + let Some(field) = field_path.next() else { + vortex_bail!("empty field path"); + }; + + let mut array = array + .as_struct_array() + .ok_or_else(|| vortex_err!("expected a struct"))? + .field_by_name(field) + .ok_or_else(|| vortex_err!("expected field to exist: {}", field))?; + + for field in field_path { + array = array + .as_struct_array() + .ok_or_else(|| vortex_err!("expected a struct"))? + .field_by_name(field) + .ok_or_else(|| vortex_err!("expected field to exist: {}", field))?; + } + Ok(array.into_primitive().unwrap()) + } + + #[test] + pub fn test_empty_pack() { + let expr = Pack::try_new_expr(Arc::new([]), Vec::new()).unwrap(); + + let test_array = test_array().into_array(); + let actual_array = expr.evaluate(&test_array).unwrap(); + assert_eq!(actual_array.len(), test_array.len()); + assert!(actual_array.as_struct_array().unwrap().nfields() == 0); + } + + #[test] + pub fn test_simple_pack() { + let expr = Pack::try_new_expr( + ["one".into(), "two".into(), "three".into()].into(), + vec![col("a"), col("b"), col("a")], + ) + .unwrap(); + + let actual_array = expr.evaluate(test_array().as_ref()).unwrap(); + let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into(); + assert_eq!( + actual_array.as_struct_array().unwrap().names(), + &expected_names + ); + + assert_eq!( + primitive_field(&actual_array, &["one"]) + .unwrap() + .as_slice::(), + [0, 1, 2] + ); + assert_eq!( + primitive_field(&actual_array, &["two"]) + .unwrap() + .as_slice::(), + [4, 5, 6] + ); + assert_eq!( + primitive_field(&actual_array, &["three"]) + .unwrap() + .as_slice::(), + [0, 1, 2] + ); + } + + #[test] + pub fn test_nested_pack() { + let expr = Pack::try_new_expr( + ["one".into(), "two".into(), "three".into()].into(), + vec![ + Column::new_expr(Field::from("a")), + Pack::try_new_expr( + ["two_one".into(), "two_two".into()].into(), + vec![ + Column::new_expr(Field::from("b")), + Column::new_expr(Field::from("b")), + ], + ) + .unwrap(), + Column::new_expr(Field::from("a")), + ], + ) + .unwrap(); + + let actual_array = expr.evaluate(test_array().as_ref()).unwrap(); + let expected_names: FieldNames = ["one".into(), "two".into(), "three".into()].into(); + assert_eq!( + actual_array.as_struct_array().unwrap().names(), + &expected_names + ); + + assert_eq!( + primitive_field(&actual_array, &["one"]) + .unwrap() + .as_slice::(), + [0, 1, 2] + ); + assert_eq!( + primitive_field(&actual_array, &["two", "two_one"]) + .unwrap() + .as_slice::(), + [4, 5, 6] + ); + assert_eq!( + primitive_field(&actual_array, &["two", "two_two"]) + .unwrap() + .as_slice::(), + [4, 5, 6] + ); + assert_eq!( + primitive_field(&actual_array, &["three"]) + .unwrap() + .as_slice::(), + [0, 1, 2] + ); + } +}