diff --git a/Cargo.lock b/Cargo.lock index 1fa63b7165..0d01f2263d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5174,7 +5174,7 @@ dependencies = [ "vortex-expr", "vortex-file", "vortex-io", - "vortex-scan", + "vortex-scalar", ] [[package]] diff --git a/encodings/bytebool/src/stats.rs b/encodings/bytebool/src/stats.rs index 794e5ab4c0..0078354516 100644 --- a/encodings/bytebool/src/stats.rs +++ b/encodings/bytebool/src/stats.rs @@ -88,8 +88,8 @@ mod tests { assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_constant().unwrap()); - assert_eq!(bool_arr.statistics().compute(Stat::Min), None); - assert_eq!(bool_arr.statistics().compute(Stat::Max), None); + assert!(bool_arr.statistics().compute(Stat::Min).is_none()); + assert!(bool_arr.statistics().compute(Stat::Max).is_none()); assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 1); assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 0); } diff --git a/encodings/datetime-parts/src/stats.rs b/encodings/datetime-parts/src/stats.rs index 131181330e..4090660136 100644 --- a/encodings/datetime-parts/src/stats.rs +++ b/encodings/datetime-parts/src/stats.rs @@ -1,14 +1,14 @@ use vortex_array::stats::{Stat, StatisticsVTable, StatsSet}; use vortex_array::ArrayLen; use vortex_error::VortexResult; -use vortex_scalar::Scalar; +use vortex_scalar::ScalarValue; use crate::{DateTimePartsArray, DateTimePartsEncoding}; impl StatisticsVTable for DateTimePartsEncoding { fn compute_statistics(&self, array: &DateTimePartsArray, stat: Stat) -> VortexResult { let maybe_stat = match stat { - Stat::NullCount => Some(Scalar::from(array.validity().null_count(array.len())?)), + Stat::NullCount => Some(ScalarValue::from(array.validity().null_count(array.len())?)), _ => None, }; diff --git a/encodings/fastlanes/src/for/compress.rs b/encodings/fastlanes/src/for/compress.rs index aa55bd3261..74f0f29241 100644 --- a/encodings/fastlanes/src/for/compress.rs +++ b/encodings/fastlanes/src/for/compress.rs @@ -8,7 +8,7 @@ use vortex_buffer::{Buffer, BufferMut}; use vortex_dtype::{ match_each_integer_ptype, match_each_unsigned_integer_ptype, DType, NativePType, Nullability, }; -use vortex_error::{vortex_bail, vortex_err, VortexExpect, VortexResult}; +use vortex_error::{vortex_bail, vortex_err, VortexExpect, VortexResult, VortexUnwrap}; use vortex_scalar::Scalar; use vortex_sparse::SparseArray; @@ -21,10 +21,11 @@ pub fn for_compress(array: PrimitiveArray) -> VortexResult { .compute(Stat::Min) .ok_or_else(|| vortex_err!("Min stat not found"))?; - let nullability = array.dtype().nullability(); + let dtype = array.dtype().clone(); + let nullability = dtype.nullability(); let encoded = match_each_integer_ptype!(array.ptype(), |$T| { if shift == <$T>::PTYPE.bit_width() as u8 { - assert_eq!(min, Scalar::zero::<$T>(array.dtype().nullability())); + assert_eq!(usize::try_from(&min).vortex_unwrap(), 0); encoded_zero::<$T>(array.validity().to_logical(array.len()), nullability) .vortex_expect("Failed to encode all zeroes") } else { @@ -34,7 +35,7 @@ pub fn for_compress(array: PrimitiveArray) -> VortexResult { .into_array() } }); - FoRArray::try_new(encoded, min, shift) + FoRArray::try_new(encoded, Scalar::new(dtype, min), shift) } fn encoded_zero( @@ -48,8 +49,7 @@ fn encoded_zero( } let encoded_ptype = T::PTYPE.to_unsigned(); - let zero = - match_each_unsigned_integer_ptype!(encoded_ptype, |$T| Scalar::zero::<$T>(nullability)); + let zero = match_each_unsigned_integer_ptype!(encoded_ptype, |$T| Scalar::primitive($T::default(), nullability)); Ok(match logical_validity { LogicalValidity::AllValid(len) => ConstantArray::new(zero, len).into_array(), diff --git a/encodings/runend/src/statistics.rs b/encodings/runend/src/statistics.rs index ba5bc1599e..39622f7d9f 100644 --- a/encodings/runend/src/statistics.rs +++ b/encodings/runend/src/statistics.rs @@ -8,7 +8,7 @@ use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType as _, ArrayLen as _, IntoArrayVariant as _}; use vortex_dtype::{match_each_unsigned_integer_ptype, DType, NativePType}; use vortex_error::VortexResult; -use vortex_scalar::Scalar; +use vortex_scalar::ScalarValue; use crate::{RunEndArray, RunEndEncoding}; @@ -16,7 +16,7 @@ impl StatisticsVTable for RunEndEncoding { fn compute_statistics(&self, array: &RunEndArray, stat: Stat) -> VortexResult { let maybe_stat = match stat { Stat::Min | Stat::Max => array.values().statistics().compute(stat), - Stat::IsSorted => Some(Scalar::from( + Stat::IsSorted => Some(ScalarValue::from( array .values() .statistics() @@ -25,10 +25,10 @@ impl StatisticsVTable for RunEndEncoding { && array.logical_validity().all_valid(), )), Stat::TrueCount => match array.dtype() { - DType::Bool(_) => Some(Scalar::from(array.true_count()?)), + DType::Bool(_) => Some(ScalarValue::from(array.true_count()?)), _ => None, }, - Stat::NullCount => Some(Scalar::from(array.null_count()?)), + Stat::NullCount => Some(ScalarValue::from(array.null_count()?)), _ => None, }; diff --git a/encodings/sparse/src/lib.rs b/encodings/sparse/src/lib.rs index bbff78a88c..3b1a15cd51 100644 --- a/encodings/sparse/src/lib.rs +++ b/encodings/sparse/src/lib.rs @@ -143,13 +143,13 @@ impl SparseArray { #[inline] pub fn fill_scalar(&self) -> Scalar { - let fill_value = ScalarValue::from_flexbytes( + let sv = ScalarValue::from_flexbytes( self.as_ref() .byte_buffer(0) .vortex_expect("Missing fill value buffer"), ) .vortex_expect("Failed to deserialize fill value"); - Scalar::new(self.dtype().clone(), fill_value) + Scalar::new(self.dtype().clone(), sv) } } @@ -173,14 +173,14 @@ impl StatisticsVTable for SparseEncoding { let fill_stats = if array.fill_scalar().is_null() { StatsSet::nulls(fill_len, array.dtype()) } else { - StatsSet::constant(&array.fill_scalar(), fill_len) + StatsSet::constant(array.fill_scalar(), fill_len) }; if values.is_empty() { return Ok(fill_stats); } - Ok(stats.merge_unordered(&fill_stats)) + Ok(stats.merge_unordered(&fill_stats, array.dtype())) } } diff --git a/encodings/zigzag/src/array.rs b/encodings/zigzag/src/array.rs index 1ae2d4cf10..5fbe82a727 100644 --- a/encodings/zigzag/src/array.rs +++ b/encodings/zigzag/src/array.rs @@ -11,7 +11,7 @@ use vortex_array::{ }; use vortex_dtype::{DType, PType}; use vortex_error::{vortex_bail, vortex_err, vortex_panic, VortexExpect as _, VortexResult}; -use vortex_scalar::Scalar; +use vortex_scalar::ScalarValue; use zigzag::ZigZag as ExternalZigZag; use crate::compress::zigzag_encode; @@ -98,15 +98,12 @@ impl StatisticsVTable for ZigZagEncoding { stats.set(stat, val); } } else if matches!(stat, Stat::Min | Stat::Max) { - let encoded_max = array - .encoded() - .statistics() - .compute_as_cast::(Stat::Max); + let encoded_max = array.encoded().statistics().compute_as::(Stat::Max); if let Some(val) = encoded_max { // the max of the encoded array is the element with the highest absolute value (so either min if negative, or max if positive) let decoded = ::decode(val); let decoded_stat = if decoded < 0 { Stat::Min } else { Stat::Max }; - stats.set(decoded_stat, Scalar::from(decoded).cast(array.dtype())?); + stats.set(decoded_stat, ScalarValue::from(decoded)); } } @@ -125,6 +122,7 @@ mod test { use vortex_array::compute::{scalar_at, slice}; use vortex_array::IntoArrayData; use vortex_buffer::buffer; + use vortex_scalar::Scalar; use super::*; @@ -133,19 +131,36 @@ mod test { let array = buffer![1i32, -5i32, 2, 3, 4, 5, 6, 7, 8, 9, 10].into_array(); let zigzag = ZigZagArray::encode(&array).unwrap(); - for stat in [Stat::Max, Stat::NullCount, Stat::IsConstant] { - let value = zigzag.statistics().compute(stat); - assert_eq!(value, array.statistics().compute(stat)); - } + assert_eq!( + zigzag.statistics().compute_max::(), + array.statistics().compute_max::() + ); + assert_eq!( + zigzag.statistics().compute_null_count(), + array.statistics().compute_null_count() + ); + assert_eq!( + zigzag.statistics().compute_is_constant(), + array.statistics().compute_is_constant() + ); let sliced = ZigZagArray::try_from(slice(zigzag, 0, 2).unwrap()).unwrap(); assert_eq!( scalar_at(&sliced, sliced.len() - 1).unwrap(), Scalar::from(-5i32) ); - for stat in [Stat::Min, Stat::NullCount, Stat::IsConstant] { - let value = sliced.statistics().compute(stat); - assert_eq!(value, array.statistics().compute(stat)); - } + + assert_eq!( + sliced.statistics().compute_min::(), + array.statistics().compute_min::() + ); + assert_eq!( + sliced.statistics().compute_null_count(), + array.statistics().compute_null_count() + ); + assert_eq!( + sliced.statistics().compute_is_constant(), + array.statistics().compute_is_constant() + ); } } diff --git a/vortex-array/src/array/bool/mod.rs b/vortex-array/src/array/bool/mod.rs index 255e189169..5f3d0cf6b7 100644 --- a/vortex-array/src/array/bool/mod.rs +++ b/vortex-array/src/array/bool/mod.rs @@ -4,7 +4,7 @@ use arrow_array::BooleanArray; use arrow_buffer::MutableBuffer; use vortex_buffer::{Alignment, ByteBuffer}; use vortex_dtype::{DType, Nullability}; -use vortex_error::{vortex_bail, VortexError, VortexExpect as _, VortexResult}; +use vortex_error::{vortex_bail, VortexExpect as _, VortexResult}; use crate::encoding::ids; use crate::stats::StatsSet; @@ -13,8 +13,8 @@ use crate::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTabl use crate::variants::{BoolArrayTrait, VariantsVTable}; use crate::visitor::{ArrayVisitor, VisitorVTable}; use crate::{ - impl_encoding, ArrayData, ArrayLen, Canonical, DeserializeMetadata, IntoArrayData, - IntoCanonical, RkyvMetadata, + impl_encoding, ArrayLen, Canonical, DeserializeMetadata, IntoArrayData, IntoCanonical, + RkyvMetadata, }; pub mod compute; diff --git a/vortex-array/src/array/bool/stats.rs b/vortex-array/src/array/bool/stats.rs index e8408847cd..6903e2e683 100644 --- a/vortex-array/src/array/bool/stats.rs +++ b/vortex-array/src/array/bool/stats.rs @@ -276,8 +276,8 @@ mod test { assert!(!bool_arr.statistics().compute_is_strict_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_sorted().unwrap()); assert!(bool_arr.statistics().compute_is_constant().unwrap()); - assert_eq!(bool_arr.statistics().compute(Stat::Min), None); - assert_eq!(bool_arr.statistics().compute(Stat::Max), None); + assert!(bool_arr.statistics().compute(Stat::Min).is_none()); + assert!(bool_arr.statistics().compute(Stat::Max).is_none()); assert_eq!(bool_arr.statistics().compute_run_count().unwrap(), 1); assert_eq!(bool_arr.statistics().compute_true_count().unwrap(), 0); assert_eq!(bool_arr.statistics().compute_null_count().unwrap(), 5); diff --git a/vortex-array/src/array/chunked/mod.rs b/vortex-array/src/array/chunked/mod.rs index b768a681f9..0df452ba52 100644 --- a/vortex-array/src/array/chunked/mod.rs +++ b/vortex-array/src/array/chunked/mod.rs @@ -5,13 +5,10 @@ use std::fmt::{Debug, Display}; use futures_util::stream; -use rkyv::{access, to_bytes}; use serde::{Deserialize, Serialize}; use vortex_buffer::BufferMut; use vortex_dtype::{DType, Nullability, PType}; -use vortex_error::{ - vortex_bail, vortex_panic, VortexError, VortexExpect as _, VortexResult, VortexUnwrap, -}; +use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult, VortexUnwrap}; use crate::array::primitive::PrimitiveArray; use crate::compute::{scalar_at, search_sorted_usize, SearchSortedSide}; diff --git a/vortex-array/src/array/chunked/stats.rs b/vortex-array/src/array/chunked/stats.rs index 9292650fc6..b00e9d1a33 100644 --- a/vortex-array/src/array/chunked/stats.rs +++ b/vortex-array/src/array/chunked/stats.rs @@ -3,6 +3,7 @@ use vortex_error::VortexResult; use crate::array::chunked::ChunkedArray; use crate::array::ChunkedEncoding; use crate::stats::{ArrayStatistics, Stat, StatisticsVTable, StatsSet}; +use crate::ArrayDType; impl StatisticsVTable for ChunkedEncoding { fn compute_statistics(&self, array: &ChunkedArray, stat: Stat) -> VortexResult { @@ -20,7 +21,7 @@ impl StatisticsVTable for ChunkedEncoding { } .unwrap_or_default() }) - .reduce(|acc, x| acc.merge_ordered(&x)) + .reduce(|acc, x| acc.merge_ordered(&x, array.dtype())) .unwrap_or_default()) } } diff --git a/vortex-array/src/array/constant/canonical.rs b/vortex-array/src/array/constant/canonical.rs index 18fa42c01b..ba407dc6bc 100644 --- a/vortex-array/src/array/constant/canonical.rs +++ b/vortex-array/src/array/constant/canonical.rs @@ -113,6 +113,7 @@ fn canonical_byte_view( #[cfg(test)] mod tests { + use enum_iterator::all; use vortex_dtype::half::f16; use vortex_dtype::{DType, Nullability, PType}; use vortex_scalar::Scalar; @@ -120,8 +121,8 @@ mod tests { use crate::array::ConstantArray; use crate::canonical::IntoArrayVariant; use crate::compute::scalar_at; - use crate::stats::{ArrayStatistics as _, StatsSet}; - use crate::{ArrayLen, IntoArrayData as _, IntoCanonical}; + use crate::stats::{ArrayStatistics as _, Stat, StatsSet}; + use crate::{ArrayDType, ArrayLen, IntoArrayData as _, IntoCanonical}; #[test] fn test_canonicalize_null() { @@ -154,8 +155,23 @@ mod tests { let canonical = const_array.into_canonical().unwrap(); let canonical_stats = canonical.statistics().to_set(); - assert_eq!(canonical_stats, StatsSet::constant(&scalar, 4)); - assert_eq!(canonical_stats, stats); + let reference = StatsSet::constant(scalar, 4); + for stat in all::() { + let canonical_stat = canonical_stats + .get(stat) + .cloned() + .map(|sv| Scalar::new(stat.dtype(canonical.dtype()), sv)); + let reference_stat = reference + .get(stat) + .cloned() + .map(|sv| Scalar::new(stat.dtype(canonical.dtype()), sv)); + let original_stat = stats + .get(stat) + .cloned() + .map(|sv| Scalar::new(stat.dtype(canonical.dtype()), sv)); + assert_eq!(canonical_stat, reference_stat); + assert_eq!(canonical_stat, original_stat); + } } #[test] diff --git a/vortex-array/src/array/constant/mod.rs b/vortex-array/src/array/constant/mod.rs index 0cc367cb37..a53b11638d 100644 --- a/vortex-array/src/array/constant/mod.rs +++ b/vortex-array/src/array/constant/mod.rs @@ -1,5 +1,4 @@ use std::fmt::Display; -use std::num::IntErrorKind::Empty; use serde::{Deserialize, Serialize}; use vortex_error::{VortexExpect, VortexResult}; @@ -25,7 +24,7 @@ impl ConstantArray { S: Into, { let scalar = scalar.into(); - let stats = StatsSet::constant(&scalar, length); + let stats = StatsSet::constant(scalar.clone(), length); let (dtype, scalar_value) = scalar.into_parts(); // Serialize the scalar_value into a FlatBuffer @@ -44,13 +43,13 @@ impl ConstantArray { /// Returns the [`Scalar`] value of this constant array. pub fn scalar(&self) -> Scalar { - let value = ScalarValue::from_flexbytes( + let sv = ScalarValue::from_flexbytes( self.as_ref() .byte_buffer(0) .vortex_expect("Missing scalar value buffer"), ) .vortex_expect("Failed to deserialize scalar value"); - Scalar::new(self.dtype().clone(), value) + Scalar::new(self.dtype().clone(), sv) } } @@ -71,7 +70,7 @@ impl ValidityVTable for ConstantEncoding { impl StatisticsVTable for ConstantEncoding { fn compute_statistics(&self, array: &ConstantArray, _stat: Stat) -> VortexResult { - Ok(StatsSet::constant(&array.scalar(), array.len())) + Ok(StatsSet::constant(array.scalar(), array.len())) } } diff --git a/vortex-array/src/array/extension/mod.rs b/vortex-array/src/array/extension/mod.rs index fd66a3b9fb..66377065a6 100644 --- a/vortex-array/src/array/extension/mod.rs +++ b/vortex-array/src/array/extension/mod.rs @@ -1,9 +1,7 @@ use std::fmt::{Debug, Display}; use std::sync::Arc; -use enum_iterator::all; use serde::{Deserialize, Serialize}; -use vortex_buffer::ByteBuffer; use vortex_dtype::{DType, ExtDType, ExtID}; use vortex_error::{VortexExpect as _, VortexResult}; @@ -34,7 +32,7 @@ impl ExtensionArray { EmptyMetadata, None, Some([storage].into()), - Default::default(), + StatsSet::default(), ) .vortex_expect("Invalid ExtensionArray") } @@ -93,17 +91,7 @@ impl VisitorVTable for ExtensionEncoding { impl StatisticsVTable for ExtensionEncoding { fn compute_statistics(&self, array: &ExtensionArray, stat: Stat) -> VortexResult { - let mut stats = array.storage().statistics().compute_all(&[stat])?; - - // for e.g., min/max, we want to cast to the extension array's dtype - // for other stats, we don't need to change anything - for stat in all::().filter(|s| s.has_same_dtype_as_array()) { - if let Some(value) = stats.get(stat) { - stats.set(stat, value.cast(array.dtype())?); - } - } - - Ok(stats) + array.storage().statistics().compute_all(&[stat]) } } @@ -111,7 +99,6 @@ impl StatisticsVTable for ExtensionEncoding { mod tests { use vortex_buffer::buffer; use vortex_dtype::PType; - use vortex_scalar::Scalar; use super::*; use crate::IntoArrayData; @@ -123,7 +110,7 @@ mod tests { DType::from(PType::I64).into(), None, )); - let array = ExtensionArray::new(ext_dtype.clone(), buffer![1i64, 2, 3, 4, 5].into_array()); + let array = ExtensionArray::new(ext_dtype, buffer![1i64, 2, 3, 4, 5].into_array()); let stats = array .statistics() @@ -136,14 +123,8 @@ mod tests { num_stats ); - assert_eq!( - stats.get(Stat::Min), - Some(&Scalar::extension(ext_dtype.clone(), Scalar::from(1_i64))) - ); - assert_eq!( - stats.get(Stat::Max), - Some(&Scalar::extension(ext_dtype, Scalar::from(5_i64))) - ); - assert_eq!(stats.get(Stat::NullCount), Some(&0u64.into())); + assert_eq!(stats.get_as::(Stat::Min), Some(1i64)); + assert_eq!(stats.get_as::(Stat::Max), Some(5_i64)); + assert_eq!(stats.get_as::(Stat::NullCount), Some(0)); } } diff --git a/vortex-array/src/array/primitive/stats.rs b/vortex-array/src/array/primitive/stats.rs index 9b74439b0c..05d9645ed5 100644 --- a/vortex-array/src/array/primitive/stats.rs +++ b/vortex-array/src/array/primitive/stats.rs @@ -7,8 +7,8 @@ use itertools::{Itertools as _, MinMaxResult}; use num_traits::PrimInt; use vortex_dtype::half::f16; use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability}; -use vortex_error::{vortex_panic, VortexResult}; -use vortex_scalar::Scalar; +use vortex_error::{vortex_panic, VortexError, VortexResult}; +use vortex_scalar::ScalarValue; use crate::array::primitive::PrimitiveArray; use crate::array::PrimitiveEncoding; @@ -18,9 +18,18 @@ use crate::validity::{ArrayValidity, LogicalValidity}; use crate::variants::PrimitiveArrayTrait; use crate::{ArrayDType, IntoArrayVariant}; -trait PStatsType: NativePType + Into + BitWidth {} +trait PStatsType: + NativePType + Into + BitWidth + for<'a> TryFrom<&'a ScalarValue, Error = VortexError> +{ +} -impl + BitWidth> PStatsType for T {} +impl PStatsType for T where + T: NativePType + + Into + + BitWidth + + for<'a> TryFrom<&'a ScalarValue, Error = VortexError> +{ +} impl StatisticsVTable for PrimitiveEncoding { fn compute_statistics(&self, array: &PrimitiveArray, stat: Stat) -> VortexResult { @@ -43,10 +52,10 @@ impl StatisticsVTable for PrimitiveEncoding { })?; if let Some(min) = stats.get(Stat::Min) { - stats.set(Stat::Min, min.cast(array.dtype())?); + stats.set(Stat::Min, min.clone()); } if let Some(max) = stats.get(Stat::Max) { - stats.set(Stat::Max, max.cast(array.dtype())?); + stats.set(Stat::Max, max.clone()); } Ok(stats) } @@ -64,8 +73,8 @@ impl StatisticsVTable<[T]> for PrimitiveEncoding { stats.set( Stat::IsConstant, stats - .get(Stat::Min) - .zip(stats.get(Stat::Max)) + .get_as::(Stat::Min) + .zip(stats.get_as::(Stat::Max)) .map(|(min, max)| min == max) .unwrap_or(false), ); @@ -169,7 +178,7 @@ fn compute_min_max( match iter.minmax_by(|a, b| a.total_compare(*b)) { MinMaxResult::NoElements => StatsSet::default(), MinMaxResult::OneElement(x) => { - let scalar: Scalar = x.into(); + let scalar = x.into(); StatsSet::new_unchecked(vec![ (Stat::Min, scalar.clone()), (Stat::Max, scalar), @@ -334,8 +343,6 @@ impl BitWidthAccumulator { #[cfg(test)] mod test { - use vortex_scalar::Scalar; - use crate::array::primitive::PrimitiveArray; use crate::stats::{ArrayStatistics, Stat}; @@ -399,9 +406,9 @@ mod test { #[test] fn all_null() { let arr = PrimitiveArray::from_option_iter([Option::::None, None, None]); - let min: Option = arr.statistics().compute(Stat::Min); - let max: Option = arr.statistics().compute(Stat::Max); - assert_eq!(min, None); - assert_eq!(max, None); + let min = arr.statistics().compute(Stat::Min); + let max = arr.statistics().compute(Stat::Max); + assert!(min.is_none()); + assert!(max.is_none()); } } diff --git a/vortex-array/src/array/varbin/stats.rs b/vortex-array/src/array/varbin/stats.rs index 9a55093bb6..354e892a38 100644 --- a/vortex-array/src/array/varbin/stats.rs +++ b/vortex-array/src/array/varbin/stats.rs @@ -4,13 +4,12 @@ use itertools::{Itertools, MinMaxResult}; use vortex_buffer::ByteBuffer; use vortex_error::{vortex_panic, VortexResult}; -use super::varbin_scalar; use crate::accessor::ArrayAccessor; use crate::array::varbin::VarBinArray; -use crate::array::VarBinEncoding; +use crate::array::{varbin_scalar, VarBinEncoding}; use crate::compute::scalar_at; use crate::stats::{Stat, StatisticsVTable, StatsSet}; -use crate::ArrayTrait; +use crate::{ArrayDType, ArrayTrait}; impl StatisticsVTable for VarBinEncoding { fn compute_statistics(&self, array: &VarBinArray, stat: Stat) -> VortexResult { @@ -53,7 +52,7 @@ pub fn compute_varbin_statistics>( let is_constant = array.with_iterator(compute_is_constant)?; if is_constant { // we know that the array is not empty - StatsSet::constant(&scalar_at(array, 0)?, array.len()) + StatsSet::constant(scalar_at(array, 0)?, array.len()) } else { StatsSet::of(Stat::IsConstant, is_constant) } @@ -112,12 +111,12 @@ fn compute_min_max>(array: &T) -> VortexResu let minmax = array.with_iterator(|iter| match iter.flatten().minmax() { MinMaxResult::NoElements => None, MinMaxResult::OneElement(value) => { - let scalar = varbin_scalar(ByteBuffer::from(value.to_vec()), array.dtype()); + let scalar = ByteBuffer::from(value.to_vec()); Some((scalar.clone(), scalar)) } MinMaxResult::MinMax(min, max) => Some(( - varbin_scalar(ByteBuffer::from(min.to_vec()), array.dtype()), - varbin_scalar(ByteBuffer::from(max.to_vec()), array.dtype()), + ByteBuffer::from(min.to_vec()), + ByteBuffer::from(max.to_vec()), )), })?; let Some((min, max)) = minmax else { @@ -129,7 +128,10 @@ fn compute_min_max>(array: &T) -> VortexResu // get (don't compute) null count if `min == max` to determine if it's constant if array.statistics().get_as::(Stat::NullCount) == Some(0) { // if there are no nulls, then the array is constant - return Ok(StatsSet::constant(&min, array.len())); + return Ok(StatsSet::constant( + varbin_scalar(min, array.dtype()), + array.len(), + )); } } else { stats.set(Stat::IsConstant, false); diff --git a/vortex-array/src/compress.rs b/vortex-array/src/compress.rs index 7498486644..bcff80f2c9 100644 --- a/vortex-array/src/compress.rs +++ b/vortex-array/src/compress.rs @@ -1,9 +1,10 @@ use vortex_error::VortexResult; +use vortex_scalar::Scalar; use crate::aliases::hash_set::HashSet; use crate::encoding::EncodingRef; use crate::stats::{ArrayStatistics as _, PRUNING_STATS}; -use crate::ArrayData; +use crate::{ArrayDType, ArrayData}; pub trait CompressionStrategy { fn compress(&self, array: &ArrayData) -> VortexResult; @@ -64,15 +65,18 @@ pub fn check_statistics_unchanged(arr: &ArrayData, compressed: &ArrayData) { .into_iter() .filter(|(stat, _)| *stat != Stat::RunCount) { + let compressed_scalar = compressed + .statistics() + .get(stat) + .map(|sv| Scalar::new(stat.dtype(compressed.dtype()), sv)); debug_assert_eq!( - compressed.statistics().get(stat), - Some(value.clone()), + compressed_scalar, + Some(Scalar::new(stat.dtype(arr.dtype()), value.clone())), "Compression changed {stat} from {value} to {}", - compressed - .statistics() - .get(stat) + compressed_scalar + .as_ref() .map(|s| s.to_string()) - .unwrap_or_else(|| "null".to_string()) + .unwrap_or_else(|| "null".to_string()), ); } } diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 0233e30326..64a7979e62 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -115,7 +115,7 @@ impl TryFrom for Mask { ); } - if let Some(true_count) = array.statistics().get_as_cast::(Stat::TrueCount) { + if let Some(true_count) = array.statistics().get_as::(Stat::TrueCount) { let len = array.len(); if true_count == 0 { return Ok(Self::new_false(len)); @@ -133,8 +133,6 @@ impl TryFrom for Mask { #[cfg(test)] mod test { - use itertools::Itertools; - use super::*; use crate::array::{BoolArray, PrimitiveArray}; use crate::compute::filter::filter; diff --git a/vortex-array/src/data/statistics.rs b/vortex-array/src/data/statistics.rs index f7a022df64..faacaa5768 100644 --- a/vortex-array/src/data/statistics.rs +++ b/vortex-array/src/data/statistics.rs @@ -1,17 +1,14 @@ -use std::sync::Arc; - use enum_iterator::all; use itertools::Itertools; -use vortex_dtype::{DType, Nullability, PType}; use vortex_error::{vortex_panic, VortexExpect as _}; -use vortex_scalar::{Scalar, ScalarValue}; +use vortex_scalar::ScalarValue; use crate::data::InnerArrayData; use crate::stats::{Stat, Statistics, StatsSet}; -use crate::{ArrayDType, ArrayData}; +use crate::ArrayData; impl Statistics for ArrayData { - fn get(&self, stat: Stat) -> Option { + fn get(&self, stat: Stat) -> Option { match &self.0 { InnerArrayData::Owned(o) => o .stats_set @@ -28,12 +25,10 @@ impl Statistics for ArrayData { Stat::Max => { let max = v.flatbuffer().stats()?.max(); max.and_then(|v| ScalarValue::try_from(v).ok()) - .map(|v| Scalar::new(self.dtype().clone(), v)) } Stat::Min => { let min = v.flatbuffer().stats()?.min(); min.and_then(|v| ScalarValue::try_from(v).ok()) - .map(|v| Scalar::new(self.dtype().clone(), v)) } Stat::IsConstant => v.flatbuffer().stats()?.is_constant().map(bool::into), Stat::IsSorted => v.flatbuffer().stats()?.is_sorted().map(bool::into), @@ -41,21 +36,18 @@ impl Statistics for ArrayData { Stat::RunCount => v.flatbuffer().stats()?.run_count().map(u64::into), Stat::TrueCount => v.flatbuffer().stats()?.true_count().map(u64::into), Stat::NullCount => v.flatbuffer().stats()?.null_count().map(u64::into), - Stat::BitWidthFreq => { - let element_dtype = - Arc::new(DType::Primitive(PType::U64, Nullability::NonNullable)); - v.flatbuffer() - .stats()? - .bit_width_freq() - .map(|v| v.iter().map(Scalar::from).collect_vec()) - .map(|v| Scalar::list(element_dtype, v, Nullability::NonNullable)) - } + Stat::BitWidthFreq => v + .flatbuffer() + .stats()? + .bit_width_freq() + .map(|v| v.iter().collect_vec()) + .map(ScalarValue::from), Stat::TrailingZeroFreq => v .flatbuffer() .stats()? .trailing_zero_freq() .map(|v| v.iter().collect_vec()) - .map(|v| v.into()), + .map(ScalarValue::from), Stat::UncompressedSizeInBytes => v .flatbuffer() .stats()? @@ -78,7 +70,7 @@ impl Statistics for ArrayData { } } - fn set(&self, stat: Stat, value: Scalar) { + fn set(&self, stat: Stat, value: ScalarValue) { match &self.0 { InnerArrayData::Owned(o) => o .stats_set @@ -111,7 +103,7 @@ impl Statistics for ArrayData { } } - fn compute(&self, stat: Stat) -> Option { + fn compute(&self, stat: Stat) -> Option { if let Some(s) = self.get(stat) { return Some(s); } diff --git a/vortex-array/src/patches.rs b/vortex-array/src/patches.rs index 9ce149ba07..225a34cc4c 100644 --- a/vortex-array/src/patches.rs +++ b/vortex-array/src/patches.rs @@ -86,7 +86,7 @@ impl Patches { "Patch indices must be shorter than the array length" ); assert!(!indices.is_empty(), "Patch indices must not be empty"); - if let Some(max) = indices.statistics().get_as_cast::(Stat::Max) { + if let Some(max) = indices.statistics().get_as::(Stat::Max) { assert!( max < array_len as u64, "Patch indices {} are longer than the array length {}", diff --git a/vortex-array/src/stats/flatbuffers.rs b/vortex-array/src/stats/flatbuffers.rs index f7ceb6b0e5..b915f1870f 100644 --- a/vortex-array/src/stats/flatbuffers.rs +++ b/vortex-array/src/stats/flatbuffers.rs @@ -21,13 +21,9 @@ impl WriteFlatBuffer for &dyn Statistics { .map(|v| v.iter().copied().collect_vec()) .map(|v| fbb.create_vector(v.as_slice())); - let min = self - .get(Stat::Min) - .map(|min| min.into_value().write_flatbuffer(fbb)); + let min = self.get(Stat::Min).map(|min| min.write_flatbuffer(fbb)); - let max = self - .get(Stat::Max) - .map(|max| max.into_value().write_flatbuffer(fbb)); + let max = self.get(Stat::Max).map(|max| max.write_flatbuffer(fbb)); let stat_args = &crate::flatbuffers::ArrayStatsArgs { min, @@ -35,12 +31,12 @@ impl WriteFlatBuffer for &dyn Statistics { is_sorted: self.get_as::(Stat::IsSorted), is_strict_sorted: self.get_as::(Stat::IsStrictSorted), is_constant: self.get_as::(Stat::IsConstant), - run_count: self.get_as_cast::(Stat::RunCount), - true_count: self.get_as_cast::(Stat::TrueCount), - null_count: self.get_as_cast::(Stat::NullCount), + run_count: self.get_as::(Stat::RunCount), + true_count: self.get_as::(Stat::TrueCount), + null_count: self.get_as::(Stat::NullCount), bit_width_freq, trailing_zero_freq, - uncompressed_size_in_bytes: self.get_as_cast::(Stat::UncompressedSizeInBytes), + uncompressed_size_in_bytes: self.get_as::(Stat::UncompressedSizeInBytes), }; crate::flatbuffers::ArrayStats::create(fbb, stat_args) diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index 74c75281fb..3b5c02006f 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -7,6 +7,7 @@ use std::sync::Arc; use arrow_buffer::bit_iterator::BitIterator; use arrow_buffer::{BooleanBufferBuilder, MutableBuffer}; use enum_iterator::{cardinality, Sequence}; +use futures_util::TryStreamExt; use itertools::Itertools; use log::debug; use num_enum::{IntoPrimitive, TryFromPrimitive}; @@ -14,7 +15,7 @@ pub use statsset::*; use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, NativePType, PType}; use vortex_error::{vortex_panic, VortexError, VortexExpect, VortexResult}; -use vortex_scalar::Scalar; +use vortex_scalar::{Scalar, ScalarValue}; use crate::encoding::Encoding; use crate::ArrayData; @@ -167,13 +168,13 @@ impl Display for Stat { pub trait Statistics { /// Returns the value of the statistic only if it's present - fn get(&self, stat: Stat) -> Option; + fn get(&self, stat: Stat) -> Option; /// Get all existing statistics fn to_set(&self) -> StatsSet; /// Set the value of the statistic - fn set(&self, stat: Stat, value: Scalar); + fn set(&self, stat: Stat, value: ScalarValue); /// Clear the value of the statistic fn clear(&self, stat: Stat); @@ -182,7 +183,7 @@ pub trait Statistics { /// /// Returns the scalar if compute succeeded, or `None` if the stat is not supported /// for this array. - fn compute(&self, stat: Stat) -> Option; + fn compute(&self, stat: Stat) -> Option; /// Compute all the requested statistics (if not already present) /// Returns a StatsSet with the requested stats and any additional available stats @@ -225,7 +226,13 @@ where } impl dyn Statistics + '_ { - pub fn get_as TryFrom<&'a Scalar, Error = VortexError>>( + /// Get the provided stat if present in the underlying array, converting the `ScalarValue` into a typed value. + /// If the stored `ScalarValue` is of different type then the primitive typed value this function will perform a cast. + /// + /// # Panics + /// + /// This function will panic if the conversion fails. + pub fn get_as TryFrom<&'a ScalarValue, Error = VortexError>>( &self, stat: Stat, ) -> Option { @@ -242,24 +249,13 @@ impl dyn Statistics + '_ { }) } - pub fn get_as_cast TryFrom<&'a Scalar, Error = VortexError>>( - &self, - stat: Stat, - ) -> Option { - self.get(stat) - .filter(|s| s.is_valid()) - .map(|s| s.cast(&DType::Primitive(U::PTYPE, NonNullable))) - .transpose() - .and_then(|maybe| maybe.as_ref().map(U::try_from).transpose()) - .unwrap_or_else(|err| { - vortex_panic!(err, "Failed to cast stat {} to {}", stat, U::PTYPE) - }) - } - - /// Get or calculate the provided stat, converting the `Scalar` into a typed value. + /// Get or calculate the provided stat, converting the `ScalarValue` into a typed value. + /// If the stored `ScalarValue` is of different type then the primitive typed value this function will perform a cast. + /// + /// # Panics /// /// This function will panic if the conversion fails. - pub fn compute_as TryFrom<&'a Scalar, Error = VortexError>>( + pub fn compute_as TryFrom<&'a ScalarValue, Error = VortexError>>( &self, stat: Stat, ) -> Option { @@ -276,31 +272,21 @@ impl dyn Statistics + '_ { }) } - pub fn compute_as_cast TryFrom<&'a Scalar, Error = VortexError>>( - &self, - stat: Stat, - ) -> Option { - self.compute(stat) - .filter(|s| s.is_valid()) - .map(|s| s.cast(&DType::Primitive(U::PTYPE, NonNullable))) - .transpose() - .and_then(|maybe| maybe.as_ref().map(U::try_from).transpose()) - .unwrap_or_else(|err| { - vortex_panic!(err, "Failed to compute stat {} as cast {}", stat, U::PTYPE) - }) - } - /// Get or calculate the minimum value in the array, returning as a typed value. /// /// This function will panic if the conversion fails. - pub fn compute_min TryFrom<&'a Scalar, Error = VortexError>>(&self) -> Option { + pub fn compute_min TryFrom<&'a ScalarValue, Error = VortexError>>( + &self, + ) -> Option { self.compute_as(Stat::Min) } /// Get or calculate the maximum value in the array, returning as a typed value. /// /// This function will panic if the conversion fails. - pub fn compute_max TryFrom<&'a Scalar, Error = VortexError>>(&self) -> Option { + pub fn compute_max TryFrom<&'a ScalarValue, Error = VortexError>>( + &self, + ) -> Option { self.compute_as(Stat::Max) } @@ -367,7 +353,7 @@ mod test { fn min_of_nulls_is_not_panic() { let min = PrimitiveArray::from_option_iter::([None, None, None, None]) .statistics() - .compute_as_cast::(Stat::Min); + .compute_as::(Stat::Min); assert_eq!(min, None); } diff --git a/vortex-array/src/stats/statsset.rs b/vortex-array/src/stats/statsset.rs index 5e24598a48..9c5660a96c 100644 --- a/vortex-array/src/stats/statsset.rs +++ b/vortex-array/src/stats/statsset.rs @@ -1,14 +1,14 @@ use enum_iterator::{all, Sequence}; use itertools::{EitherOrBoth, Itertools}; use vortex_dtype::DType; -use vortex_error::{vortex_panic, VortexError, VortexExpect}; -use vortex_scalar::Scalar; +use vortex_error::{vortex_panic, VortexError, VortexExpect, VortexUnwrap}; +use vortex_scalar::{Scalar, ScalarValue}; use crate::stats::Stat; -#[derive(Default, Debug, Clone, PartialEq, Eq)] +#[derive(Default, Debug, Clone)] pub struct StatsSet { - values: Option>, + values: Option>, } impl StatsSet { @@ -17,17 +17,12 @@ impl StatsSet { /// # Safety /// /// This method will not panic or trigger UB, but may lead to duplicate stats being stored. - pub fn new_unchecked(values: Vec<(Stat, Scalar)>) -> Self { + pub fn new_unchecked(values: Vec<(Stat, ScalarValue)>) -> Self { Self { values: Some(values), } } - /// Create a new, empty StatsSet. - pub fn empty() -> Self { - Self { values: None } - } - /// Specialized constructor for the case where the StatsSet represents /// an array consisting entirely of [null](vortex_dtype::DType::Null) values. pub fn nulls(len: usize, dtype: &DType) -> Self { @@ -61,7 +56,8 @@ impl StatsSet { stats } - pub fn constant(scalar: &Scalar, length: usize) -> Self { + pub fn constant(scalar: Scalar, length: usize) -> Self { + let (dtype, sv) = scalar.into_parts(); let mut stats = Self::default(); if length > 0 { stats.set(Stat::IsConstant, true); @@ -72,22 +68,22 @@ impl StatsSet { let run_count = if length == 0 { 0u64 } else { 1 }; stats.set(Stat::RunCount, run_count); - let null_count = if scalar.is_null() { length as u64 } else { 0 }; + let null_count = if sv.is_null() { length as u64 } else { 0 }; stats.set(Stat::NullCount, null_count); - if let Some(bool_scalar) = scalar.as_bool_opt() { - let true_count = bool_scalar - .value() + if !sv.is_null() { + stats.set(Stat::Min, sv.clone()); + stats.set(Stat::Max, sv.clone()); + } + + if matches!(dtype, DType::Bool(_)) { + let bool_val = >::try_from(&sv).vortex_expect("Checked dtype"); + let true_count = bool_val .map(|b| if b { length as u64 } else { 0 }) .unwrap_or(0); stats.set(Stat::TrueCount, true_count); } - if !scalar.is_null() { - stats.set(Stat::Min, scalar.clone()); - stats.set(Stat::Max, scalar.clone()); - } - stats } @@ -108,7 +104,7 @@ impl StatsSet { ]) } - pub fn of>(stat: Stat, value: S) -> Self { + pub fn of>(stat: Stat, value: S) -> Self { Self::new_unchecked(vec![(stat, value.into())]) } } @@ -125,13 +121,13 @@ impl StatsSet { self.values.as_ref().is_none_or(|v| v.is_empty()) } - pub fn get(&self, stat: Stat) -> Option<&Scalar> { + pub fn get(&self, stat: Stat) -> Option<&ScalarValue> { self.values .as_ref() .and_then(|v| v.iter().find(|(s, _)| *s == stat).map(|(_, v)| v)) } - pub fn get_as TryFrom<&'a Scalar, Error = VortexError>>( + pub fn get_as TryFrom<&'a ScalarValue, Error = VortexError>>( &self, stat: Stat, ) -> Option { @@ -148,7 +144,7 @@ impl StatsSet { } /// Set the stat `stat` to `value`. - pub fn set>(&mut self, stat: Stat, value: S) { + pub fn set>(&mut self, stat: Stat, value: S) { if self.values.is_none() { self.values = Some(Vec::with_capacity(Stat::CARDINALITY)); } @@ -176,7 +172,7 @@ impl StatsSet { /// Iterate over the statistic names and values in-place. /// /// See [Iterator]. - pub fn iter(&self) -> impl Iterator { + pub fn iter(&self) -> impl Iterator { self.values.iter().flat_map(|v| v.iter()) } } @@ -186,10 +182,10 @@ impl StatsSet { /// Owned iterator over the stats. /// /// See [IntoIterator]. -pub struct StatsSetIntoIter(Option>); +pub struct StatsSetIntoIter(Option>); impl Iterator for StatsSetIntoIter { - type Item = (Stat, Scalar); + type Item = (Stat, ScalarValue); fn next(&mut self) -> Option { self.0.as_mut().and_then(|i| i.next()) @@ -197,7 +193,7 @@ impl Iterator for StatsSetIntoIter { } impl IntoIterator for StatsSet { - type Item = (Stat, Scalar); + type Item = (Stat, ScalarValue); type IntoIter = StatsSetIntoIter; fn into_iter(self) -> Self::IntoIter { @@ -205,8 +201,8 @@ impl IntoIterator for StatsSet { } } -impl FromIterator<(Stat, Scalar)> for StatsSet { - fn from_iter>(iter: T) -> Self { +impl FromIterator<(Stat, ScalarValue)> for StatsSet { + fn from_iter>(iter: T) -> Self { let iter = iter.into_iter(); let (lower_bound, _) = iter.size_hint(); let mut this = Self { @@ -217,9 +213,9 @@ impl FromIterator<(Stat, Scalar)> for StatsSet { } } -impl Extend<(Stat, Scalar)> for StatsSet { +impl Extend<(Stat, ScalarValue)> for StatsSet { #[inline] - fn extend>(&mut self, iter: T) { + fn extend>(&mut self, iter: T) { let iter = iter.into_iter(); let (lower_bound, _) = iter.size_hint(); if let Some(v) = &mut self.values { @@ -233,16 +229,16 @@ impl Extend<(Stat, Scalar)> for StatsSet { impl StatsSet { /// Merge stats set `other` into `self`, with the semantic assumption that `other` /// contains stats from an array that is *appended* to the array represented by `self`. - pub fn merge_ordered(mut self, other: &Self) -> Self { + pub fn merge_ordered(mut self, other: &Self, dtype: &DType) -> Self { for s in all::() { match s { Stat::BitWidthFreq => self.merge_bit_width_freq(other), Stat::TrailingZeroFreq => self.merge_trailing_zero_freq(other), - Stat::IsConstant => self.merge_is_constant(other), - Stat::IsSorted => self.merge_is_sorted(other), - Stat::IsStrictSorted => self.merge_is_strict_sorted(other), - Stat::Max => self.merge_max(other), - Stat::Min => self.merge_min(other), + Stat::IsConstant => self.merge_is_constant(other, dtype), + Stat::IsSorted => self.merge_is_sorted(other, dtype), + Stat::IsStrictSorted => self.merge_is_strict_sorted(other, dtype), + Stat::Max => self.merge_max(other, dtype), + Stat::Min => self.merge_min(other, dtype), Stat::RunCount => self.merge_run_count(other), Stat::TrueCount => self.merge_true_count(other), Stat::NullCount => self.merge_null_count(other), @@ -255,7 +251,7 @@ impl StatsSet { /// Merge stats set `other` into `self`, with no assumption on ordering. /// Stats that are not commutative (e.g., is_sorted) are dropped from the result. - pub fn merge_unordered(mut self, other: &Self) -> Self { + pub fn merge_unordered(mut self, other: &Self, dtype: &DType) -> Self { for s in all::() { if !s.is_commutative() { self.clear(s); @@ -265,9 +261,9 @@ impl StatsSet { match s { Stat::BitWidthFreq => self.merge_bit_width_freq(other), Stat::TrailingZeroFreq => self.merge_trailing_zero_freq(other), - Stat::IsConstant => self.merge_is_constant(other), - Stat::Max => self.merge_max(other), - Stat::Min => self.merge_min(other), + Stat::IsConstant => self.merge_is_constant(other, dtype), + Stat::Max => self.merge_max(other, dtype), + Stat::Min => self.merge_min(other, dtype), Stat::TrueCount => self.merge_true_count(other), Stat::NullCount => self.merge_null_count(other), Stat::UncompressedSizeInBytes => self.merge_uncompressed_size_in_bytes(other), @@ -278,10 +274,10 @@ impl StatsSet { self } - fn merge_min(&mut self, other: &Self) { + fn merge_min(&mut self, other: &Self, dtype: &DType) { match (self.get(Stat::Min), other.get(Stat::Min)) { (Some(m1), Some(m2)) => { - if m2 < m1 { + if Scalar::new(dtype.clone(), m2.clone()) < Scalar::new(dtype.clone(), m1.clone()) { self.set(Stat::Min, m2.clone()); } } @@ -289,10 +285,10 @@ impl StatsSet { } } - fn merge_max(&mut self, other: &Self) { + fn merge_max(&mut self, other: &Self, dtype: &DType) { match (self.get(Stat::Max), other.get(Stat::Max)) { (Some(m1), Some(m2)) => { - if m2 > m1 { + if Scalar::new(dtype.clone(), m2.clone()) > Scalar::new(dtype.clone(), m1.clone()) { self.set(Stat::Max, m2.clone()); } } @@ -300,10 +296,20 @@ impl StatsSet { } } - fn merge_is_constant(&mut self, other: &Self) { + fn merge_is_constant(&mut self, other: &Self, dtype: &DType) { if let Some(is_constant) = self.get_as(Stat::IsConstant) { if let Some(other_is_constant) = other.get_as(Stat::IsConstant) { - if is_constant && other_is_constant && self.get(Stat::Min) == other.get(Stat::Min) { + if is_constant + && other_is_constant + && self + .get(Stat::Min) + .cloned() + .map(|sv| Scalar::new(dtype.clone(), sv)) + == other + .get(Stat::Min) + .cloned() + .map(|sv| Scalar::new(dtype.clone(), sv)) + { return; } } @@ -311,18 +317,19 @@ impl StatsSet { } } - fn merge_is_sorted(&mut self, other: &Self) { - self.merge_sortedness_stat(other, Stat::IsSorted, |own, other| own <= other) + fn merge_is_sorted(&mut self, other: &Self, dtype: &DType) { + self.merge_sortedness_stat(other, Stat::IsSorted, dtype, |own, other| own <= other) } - fn merge_is_strict_sorted(&mut self, other: &Self) { - self.merge_sortedness_stat(other, Stat::IsStrictSorted, |own, other| own < other) + fn merge_is_strict_sorted(&mut self, other: &Self, dtype: &DType) { + self.merge_sortedness_stat(other, Stat::IsStrictSorted, dtype, |own, other| own < other) } - fn merge_sortedness_stat, Option<&Scalar>) -> bool>( + fn merge_sortedness_stat, Option) -> bool>( &mut self, other: &Self, stat: Stat, + dtype: &DType, cmp: F, ) { if let Some(is_sorted) = self.get_as(stat) { @@ -331,7 +338,15 @@ impl StatsSet { self.clear(stat); } else if is_sorted && other_is_sorted - && cmp(self.get(Stat::Max), other.get(Stat::Min)) + && cmp( + self.get(Stat::Max) + .cloned() + .map(|sv| Scalar::new(dtype.clone(), sv)), + other + .get(Stat::Min) + .cloned() + .map(|sv| Scalar::new(dtype.clone(), sv)), + ) { return; } else { @@ -411,6 +426,7 @@ impl StatsSet { mod test { use enum_iterator::all; use itertools::Itertools; + use vortex_dtype::{DType, Nullability, PType}; use crate::array::PrimitiveArray; use crate::stats::{ArrayStatistics as _, Stat, StatsSet}; @@ -419,109 +435,156 @@ mod test { #[test] fn test_iter() { let set = StatsSet::new_unchecked(vec![(Stat::Max, 100.into()), (Stat::Min, 42.into())]); - assert_eq!( - set.iter().cloned().collect_vec(), - vec![(Stat::Max, 100.into()), (Stat::Min, 42.into())] - ); + let mut iter = set.iter(); + let first = iter.next().unwrap(); + assert_eq!(first.0, Stat::Max); + assert_eq!(i32::try_from(&first.1).unwrap(), 100); + let snd = iter.next().unwrap(); + assert_eq!(snd.0, Stat::Min); + assert_eq!(i32::try_from(&snd.1).unwrap(), 42); } #[test] fn into_iter() { - let set = StatsSet::new_unchecked(vec![(Stat::Max, 100.into()), (Stat::Min, 42.into())]); - assert_eq!( - set.into_iter().collect_vec(), - vec![(Stat::Max, 100.into()), (Stat::Min, 42.into())] - ); + let mut set = + StatsSet::new_unchecked(vec![(Stat::Max, 100.into()), (Stat::Min, 42.into())]) + .into_iter(); + let first = set.next().unwrap(); + assert_eq!(first.0, Stat::Max); + assert_eq!(i32::try_from(&first.1).unwrap(), 100); + let snd = set.next().unwrap(); + assert_eq!(snd.0, Stat::Min); + assert_eq!(i32::try_from(&snd.1).unwrap(), 42); } #[test] fn merge_into_min() { - let first = StatsSet::of(Stat::Min, 42).merge_ordered(&StatsSet::default()); - assert_eq!(first.get(Stat::Min), None); + let first = StatsSet::of(Stat::Min, 42).merge_ordered( + &StatsSet::default(), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::Min).is_none()); } #[test] fn merge_from_min() { - let first = StatsSet::default().merge_ordered(&StatsSet::of(Stat::Min, 42)); - assert_eq!(first.get(Stat::Min), None); + let first = StatsSet::default().merge_ordered( + &StatsSet::of(Stat::Min, 42), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::Min).is_none()); } #[test] fn merge_mins() { - let first = StatsSet::of(Stat::Min, 37).merge_ordered(&StatsSet::of(Stat::Min, 42)); - assert_eq!(first.get(Stat::Min).cloned(), Some(37.into())); + let first = StatsSet::of(Stat::Min, 37).merge_ordered( + &StatsSet::of(Stat::Min, 42), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert_eq!(first.get_as::(Stat::Min), Some(37)); } #[test] fn merge_into_max() { - let first = StatsSet::of(Stat::Max, 42).merge_ordered(&StatsSet::default()); - assert_eq!(first.get(Stat::Max), None); + let first = StatsSet::of(Stat::Max, 42).merge_ordered( + &StatsSet::default(), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::Max).is_none()); } #[test] fn merge_from_max() { - let first = StatsSet::default().merge_ordered(&StatsSet::of(Stat::Max, 42)); - assert_eq!(first.get(Stat::Max), None); + let first = StatsSet::default().merge_ordered( + &StatsSet::of(Stat::Max, 42), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::Max).is_none()); } #[test] fn merge_maxes() { - let first = StatsSet::of(Stat::Max, 37).merge_ordered(&StatsSet::of(Stat::Max, 42)); - assert_eq!(first.get(Stat::Max).cloned(), Some(42.into())); + let first = StatsSet::of(Stat::Max, 37).merge_ordered( + &StatsSet::of(Stat::Max, 42), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert_eq!(first.get_as::(Stat::Max), Some(42)); } #[test] fn merge_into_scalar() { - let first = StatsSet::of(Stat::TrueCount, 42).merge_ordered(&StatsSet::default()); - assert_eq!(first.get(Stat::TrueCount), None); + let first = StatsSet::of(Stat::TrueCount, 42).merge_ordered( + &StatsSet::default(), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::TrueCount).is_none()); } #[test] fn merge_from_scalar() { - let first = StatsSet::default().merge_ordered(&StatsSet::of(Stat::TrueCount, 42)); - assert_eq!(first.get(Stat::TrueCount), None); + let first = StatsSet::default().merge_ordered( + &StatsSet::of(Stat::TrueCount, 42), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::TrueCount).is_none()); } #[test] fn merge_scalars() { - let first = - StatsSet::of(Stat::TrueCount, 37).merge_ordered(&StatsSet::of(Stat::TrueCount, 42)); - assert_eq!(first.get(Stat::TrueCount).cloned(), Some(79u64.into())); + let first = StatsSet::of(Stat::TrueCount, 37).merge_ordered( + &StatsSet::of(Stat::TrueCount, 42), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert_eq!(first.get_as::(Stat::TrueCount), Some(79)); } #[test] fn merge_into_freq() { let vec = (0usize..255).collect_vec(); - let first = StatsSet::of(Stat::BitWidthFreq, vec).merge_ordered(&StatsSet::default()); - assert_eq!(first.get(Stat::BitWidthFreq), None); + let first = StatsSet::of(Stat::BitWidthFreq, vec).merge_ordered( + &StatsSet::default(), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::BitWidthFreq).is_none()); } #[test] fn merge_from_freq() { let vec = (0usize..255).collect_vec(); - let first = StatsSet::default().merge_ordered(&StatsSet::of(Stat::BitWidthFreq, vec)); - assert_eq!(first.get(Stat::BitWidthFreq), None); + let first = StatsSet::default().merge_ordered( + &StatsSet::of(Stat::BitWidthFreq, vec), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::BitWidthFreq).is_none()); } #[test] fn merge_freqs() { let vec_in = vec![5u64; 256]; let vec_out = vec![10u64; 256]; - let first = StatsSet::of(Stat::BitWidthFreq, vec_in.clone()) - .merge_ordered(&StatsSet::of(Stat::BitWidthFreq, vec_in)); - assert_eq!(first.get(Stat::BitWidthFreq).cloned(), Some(vec_out.into())); + let first = StatsSet::of(Stat::BitWidthFreq, vec_in.clone()).merge_ordered( + &StatsSet::of(Stat::BitWidthFreq, vec_in), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert_eq!(first.get_as::>(Stat::BitWidthFreq), Some(vec_out)); } #[test] fn merge_into_sortedness() { - let first = StatsSet::of(Stat::IsStrictSorted, true).merge_ordered(&StatsSet::default()); - assert_eq!(first.get(Stat::IsStrictSorted), None); + let first = StatsSet::of(Stat::IsStrictSorted, true).merge_ordered( + &StatsSet::default(), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::IsStrictSorted).is_none()); } #[test] fn merge_from_sortedness() { - let first = StatsSet::default().merge_ordered(&StatsSet::of(Stat::IsStrictSorted, true)); - assert_eq!(first.get(Stat::IsStrictSorted), None); + let first = StatsSet::default().merge_ordered( + &StatsSet::of(Stat::IsStrictSorted, true), + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::IsStrictSorted).is_none()); } #[test] @@ -530,8 +593,11 @@ mod test { first.set(Stat::Max, 1); let mut second = StatsSet::of(Stat::IsStrictSorted, true); second.set(Stat::Min, 2); - first = first.merge_ordered(&second); - assert_eq!(first.get(Stat::IsStrictSorted).cloned(), Some(true.into())); + first = first.merge_ordered( + &second, + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert_eq!(first.get_as::(Stat::IsStrictSorted), Some(true)); } #[test] @@ -540,11 +606,11 @@ mod test { first.set(Stat::Min, 1); let mut second = StatsSet::of(Stat::IsStrictSorted, true); second.set(Stat::Max, 2); - second = second.merge_ordered(&first); - assert_eq!( - second.get(Stat::IsStrictSorted).cloned(), - Some(false.into()) + second = second.merge_ordered( + &first, + &DType::Primitive(PType::I32, Nullability::NonNullable), ); + assert_eq!(second.get_as::(Stat::IsStrictSorted), Some(false)); } #[test] @@ -553,11 +619,11 @@ mod test { first.set(Stat::Max, 1); let mut second = StatsSet::of(Stat::IsStrictSorted, false); second.set(Stat::Min, 2); - first.merge_ordered(&second); - assert_eq!( - second.get(Stat::IsStrictSorted).cloned(), - Some(false.into()) + first.merge_ordered( + &second, + &DType::Primitive(PType::I32, Nullability::NonNullable), ); + assert_eq!(second.get_as::(Stat::IsStrictSorted), Some(false)); } #[test] @@ -565,8 +631,11 @@ mod test { let mut first = StatsSet::of(Stat::IsStrictSorted, true); first.set(Stat::Max, 1); let second = StatsSet::of(Stat::IsStrictSorted, true); - first = first.merge_ordered(&second); - assert_eq!(first.get(Stat::IsStrictSorted).cloned(), None); + first = first.merge_ordered( + &second, + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); + assert!(first.get(Stat::IsStrictSorted).is_none()); } #[test] @@ -584,7 +653,10 @@ mod test { assert!(stats.get(*stat).is_some(), "Stat {} is missing", stat); } - let merged = stats.clone().merge_unordered(&stats); + let merged = stats.clone().merge_unordered( + &stats, + &DType::Primitive(PType::I32, Nullability::NonNullable), + ); for stat in &all_stats { assert_eq!( merged.get(*stat).is_some(), @@ -594,21 +666,17 @@ mod test { ) } - assert_eq!(merged.get(Stat::Min), stats.get(Stat::Min)); - assert_eq!(merged.get(Stat::Max), stats.get(Stat::Max)); assert_eq!( - merged - .get(Stat::NullCount) - .unwrap() - .as_primitive() - .typed_value::() - .unwrap(), - 2 * stats - .get(Stat::NullCount) - .unwrap() - .as_primitive() - .typed_value::() - .unwrap() + merged.get_as::(Stat::Min), + stats.get_as::(Stat::Min) + ); + assert_eq!( + merged.get_as::(Stat::Max), + stats.get_as::(Stat::Max) + ); + assert_eq!( + merged.get_as::(Stat::NullCount).unwrap(), + 2 * stats.get_as::(Stat::NullCount).unwrap() ); } } diff --git a/vortex-array/src/stream/take_rows.rs b/vortex-array/src/stream/take_rows.rs index c0f6c1f0f9..7caa741855 100644 --- a/vortex-array/src/stream/take_rows.rs +++ b/vortex-array/src/stream/take_rows.rs @@ -44,7 +44,7 @@ impl TakeRows { if indices.dtype().is_signed_int() && indices .statistics() - .compute_as_cast::(Stat::Min) + .compute_as::(Stat::Min) .map(|min| min < 0) .unwrap_or(true) { diff --git a/vortex-datafusion/Cargo.toml b/vortex-datafusion/Cargo.toml index d87eb64d09..6b1d4b5f1e 100644 --- a/vortex-datafusion/Cargo.toml +++ b/vortex-datafusion/Cargo.toml @@ -42,7 +42,7 @@ vortex-error = { workspace = true, features = ["datafusion"] } vortex-expr = { workspace = true, features = ["datafusion"] } vortex-file = { workspace = true, features = ["object_store", "tokio"] } vortex-io = { workspace = true, features = ["object_store", "tokio"] } -vortex-scan = { workspace = true } +vortex-scalar = { workspace = true } [dev-dependencies] anyhow = { workspace = true } diff --git a/vortex-datafusion/src/memory/statistics.rs b/vortex-datafusion/src/memory/statistics.rs index 565449c450..1d2bf848a9 100644 --- a/vortex-datafusion/src/memory/statistics.rs +++ b/vortex-datafusion/src/memory/statistics.rs @@ -4,9 +4,10 @@ use itertools::Itertools; use vortex_array::array::ChunkedArray; use vortex_array::stats::{ArrayStatistics, Stat}; use vortex_array::variants::StructArrayTrait; -use vortex_array::ArrayLen; +use vortex_array::{ArrayDType, ArrayLen}; use vortex_dtype::FieldNames; use vortex_error::{vortex_err, VortexExpect, VortexResult}; +use vortex_scalar::Scalar; pub(crate) fn chunked_array_df_stats( array: &ChunkedArray, @@ -32,6 +33,7 @@ pub(crate) fn chunked_array_df_stats( max_value: arr .statistics() .get(Stat::Max) + .map(|n| Scalar::new(array.dtype().clone(), n)) .map(|n| { ScalarValue::try_from(n).vortex_expect("cannot convert scalar to df scalar") }) @@ -40,6 +42,7 @@ pub(crate) fn chunked_array_df_stats( min_value: arr .statistics() .get(Stat::Min) + .map(|n| Scalar::new(array.dtype().clone(), n)) .map(|n| { ScalarValue::try_from(n).vortex_expect("cannot convert scalar to df scalar") }) diff --git a/vortex-datafusion/src/persistent/format.rs b/vortex-datafusion/src/persistent/format.rs index a5e1078323..e8bc8f2290 100644 --- a/vortex-datafusion/src/persistent/format.rs +++ b/vortex-datafusion/src/persistent/format.rs @@ -18,13 +18,14 @@ use datafusion_physical_plan::metrics::ExecutionPlanMetricsSet; use datafusion_physical_plan::ExecutionPlan; use futures::{stream, StreamExt as _, TryStreamExt as _}; use object_store::{ObjectMeta, ObjectStore}; -use vortex_array::arrow::infer_schema; +use vortex_array::arrow::{infer_schema, FromArrowType}; use vortex_array::stats::Stat; use vortex_array::ContextRef; -use vortex_dtype::FieldPath; +use vortex_dtype::{DType, FieldPath}; use vortex_error::{vortex_err, VortexExpect, VortexResult}; use vortex_file::{VortexOpenOptions, VORTEX_FILE_EXTENSION}; use vortex_io::ObjectStoreReadAt; +use vortex_scalar::Scalar; use super::cache::FileLayoutCache; use super::execution::VortexExec; @@ -176,15 +177,18 @@ impl FileFormat for VortexFormat { let column_statistics = stats .into_iter() - .map(|s| { + .zip(table_schema.fields().iter()) + .map(|(s, f)| { let null_count = s.get_as::(Stat::NullCount); let min = s .get(Stat::Min) .cloned() + .map(|n| Scalar::new(DType::from_arrow(f.as_ref()), n)) .and_then(|s| ScalarValue::try_from(s).ok()); let max = s .get(Stat::Max) .cloned() + .map(|n| Scalar::new(DType::from_arrow(f.as_ref()), n)) .and_then(|s| ScalarValue::try_from(s).ok()); ColumnStatistics { null_count: null_count diff --git a/vortex-layout/src/layouts/chunked/eval_expr.rs b/vortex-layout/src/layouts/chunked/eval_expr.rs index e24d75125e..c5141d13f0 100644 --- a/vortex-layout/src/layouts/chunked/eval_expr.rs +++ b/vortex-layout/src/layouts/chunked/eval_expr.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use futures::future::{ready, try_join_all}; use futures::FutureExt; use vortex_array::array::{ChunkedArray, ConstantArray}; -use vortex_array::{ArrayDType, ArrayData, Canonical, IntoArrayData}; +use vortex_array::{ArrayData, IntoArrayData}; use vortex_error::VortexResult; use vortex_expr::ExprRef; use vortex_scalar::Scalar; @@ -20,10 +20,7 @@ impl ExprEvaluator for ChunkedReader { expr: ExprRef, ) -> VortexResult { // Compute the result dtype of the expression. - let dtype = expr - .evaluate(&Canonical::empty(self.dtype())?.into_array())? - .dtype() - .clone(); + let dtype = expr.return_dtype(self.dtype())?; // First we need to compute the pruning mask let pruning_mask = self.pruning_mask(&expr).await?; diff --git a/vortex-layout/src/layouts/chunked/eval_stats.rs b/vortex-layout/src/layouts/chunked/eval_stats.rs index 465cbc08a0..2f3011f175 100644 --- a/vortex-layout/src/layouts/chunked/eval_stats.rs +++ b/vortex-layout/src/layouts/chunked/eval_stats.rs @@ -21,7 +21,7 @@ impl StatsEvaluator for ChunkedReader { // Otherwise, fetch the stats table let Some(stats_table) = self.stats_table().await? else { - return Ok(vec![StatsSet::empty(); field_paths.len()]); + return Ok(vec![StatsSet::default(); field_paths.len()]); }; let mut stat_sets = Vec::with_capacity(field_paths.len()); @@ -30,7 +30,7 @@ impl StatsEvaluator for ChunkedReader { // TODO(ngates): the stats table only stores a single array, so we can only answer // stats if the field path == root. // See for more details. - stat_sets.push(StatsSet::empty()); + stat_sets.push(StatsSet::default()); continue; } stat_sets.push(stats_table.to_stats_set(&stats)?); diff --git a/vortex-layout/src/layouts/chunked/stats_table.rs b/vortex-layout/src/layouts/chunked/stats_table.rs index 443187d68f..e26a9bb4f4 100644 --- a/vortex-layout/src/layouts/chunked/stats_table.rs +++ b/vortex-layout/src/layouts/chunked/stats_table.rs @@ -9,6 +9,7 @@ use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_dtype::{DType, Nullability, PType, StructDType}; use vortex_error::{vortex_bail, VortexExpect, VortexResult}; +use vortex_scalar::Scalar; /// A table of statistics for a column. /// Each row of the stats table corresponds to a chunk of the column. @@ -149,7 +150,7 @@ impl StatsAccumulator { pub fn push_chunk(&mut self, array: &ArrayData) -> VortexResult<()> { for (s, builder) in self.stats.iter().zip_eq(self.builders.iter_mut()) { if let Some(v) = array.statistics().compute(*s) { - builder.append_scalar(&v.cast(builder.dtype())?)?; + builder.append_scalar(&Scalar::new(s.dtype(array.dtype()), v))?; } else { builder.append_null(); } diff --git a/vortex-layout/src/layouts/struct_/eval_stats.rs b/vortex-layout/src/layouts/struct_/eval_stats.rs index 96b451e842..352bd0ed3a 100644 --- a/vortex-layout/src/layouts/struct_/eval_stats.rs +++ b/vortex-layout/src/layouts/struct_/eval_stats.rs @@ -21,7 +21,7 @@ impl StatsEvaluator for StructReader { for path in field_paths.iter() { if path.is_root() { // We don't have any stats for a struct layout - futures.push(ready(Ok(vec![StatsSet::empty()])).boxed()); + futures.push(ready(Ok(vec![StatsSet::default()])).boxed()); } else { // Otherwise, strip off the first path element and delegate to the child layout let Field::Name(field) = path.path()[0] diff --git a/vortex-sampling-compressor/Cargo.toml b/vortex-sampling-compressor/Cargo.toml index 9c24323746..bdca4bf31f 100644 --- a/vortex-sampling-compressor/Cargo.toml +++ b/vortex-sampling-compressor/Cargo.toml @@ -32,6 +32,7 @@ vortex-fastlanes = { workspace = true } vortex-fsst = { workspace = true } vortex-runend = { workspace = true } vortex-sparse = { workspace = true } +vortex-scalar = { workspace = true } vortex-zigzag = { workspace = true } [dev-dependencies] diff --git a/vortex-sampling-compressor/src/compressors/zigzag.rs b/vortex-sampling-compressor/src/compressors/zigzag.rs index 99ad2f8134..c7afdc499f 100644 --- a/vortex-sampling-compressor/src/compressors/zigzag.rs +++ b/vortex-sampling-compressor/src/compressors/zigzag.rs @@ -35,7 +35,7 @@ impl EncodingCompressor for ZigZagCompressor { // TODO(ngates): also check that Stat::Max is less than half the max value of the type parray .statistics() - .compute_as_cast::(Stat::Min) + .compute_as::(Stat::Min) .filter(|&min| min < 0) .map(|_| self as &dyn EncodingCompressor) } diff --git a/vortex-sampling-compressor/src/downscale.rs b/vortex-sampling-compressor/src/downscale.rs index d0544818f6..34deab682f 100644 --- a/vortex-sampling-compressor/src/downscale.rs +++ b/vortex-sampling-compressor/src/downscale.rs @@ -4,7 +4,7 @@ use vortex_array::encoding::EncodingVTable; use vortex_array::stats::{ArrayStatistics, Stat}; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_dtype::{DType, PType}; -use vortex_error::{vortex_err, VortexResult}; +use vortex_error::{vortex_err, VortexExpect, VortexResult}; /// Downscale a primitive array to the narrowest PType that fits all the values. pub fn downscale_integer_array(array: ArrayData) -> VortexResult { @@ -12,7 +12,7 @@ pub fn downscale_integer_array(array: ArrayData) -> VortexResult { // This can happen if e.g. the array is ConstantArray. return Ok(array); } - let array = PrimitiveArray::try_from(array)?; + let array = PrimitiveArray::maybe_from(array).vortex_expect("Checked earlier"); let min = array .statistics() @@ -25,15 +25,14 @@ pub fn downscale_integer_array(array: ArrayData) -> VortexResult { // If we can't cast to i64, then leave the array as its original type. // It's too big to downcast anyway. - let Ok(min) = min.cast(&DType::Primitive(PType::I64, array.dtype().nullability())) else { + let Ok(min) = i64::try_from(&min) else { return Ok(array.into_array()); }; - let Ok(max) = max.cast(&DType::Primitive(PType::I64, array.dtype().nullability())) else { + let Ok(max) = i64::try_from(&max) else { return Ok(array.into_array()); }; - downscale_primitive_integer_array(array, i64::try_from(min)?, i64::try_from(max)?) - .map(|a| a.into_array()) + downscale_primitive_integer_array(array, min, max).map(|a| a.into_array()) } /// Downscale a primitive array to the narrowest PType that fits all the values. diff --git a/vortex-sampling-compressor/tests/smoketest.rs b/vortex-sampling-compressor/tests/smoketest.rs index 39b74d6b2a..8d9ec97ba5 100644 --- a/vortex-sampling-compressor/tests/smoketest.rs +++ b/vortex-sampling-compressor/tests/smoketest.rs @@ -22,7 +22,6 @@ mod tests { use vortex_fastlanes::BitPackedEncoding; use vortex_fsst::FSSTEncoding; use vortex_sampling_compressor::ALL_COMPRESSORS; - use vortex_scalar::Scalar; use super::*; @@ -125,8 +124,10 @@ mod tests { for chunk in prim_col.chunks() { assert_eq!(chunk.encoding().id(), BitPackedEncoding::ID); assert_eq!( - chunk.statistics().get(Stat::UncompressedSizeInBytes), - Some(Scalar::from((chunk.len() * 8) as u64 + 1)) + chunk + .statistics() + .get_as::(Stat::UncompressedSizeInBytes), + Some((chunk.len() * 8) + 1) ); } @@ -138,8 +139,10 @@ mod tests { for chunk in bool_col.chunks() { assert_eq!(chunk.encoding().id(), BoolEncoding::ID); assert_eq!( - chunk.statistics().get(Stat::UncompressedSizeInBytes), - Some(Scalar::from(chunk.len().div_ceil(8) as u64 + 2)) + chunk + .statistics() + .get_as::(Stat::UncompressedSizeInBytes), + Some(chunk.len().div_ceil(8) + 2) ); } @@ -154,8 +157,10 @@ mod tests { || chunk.encoding().id() == FSSTEncoding::ID ); assert_eq!( - chunk.statistics().get(Stat::UncompressedSizeInBytes), - Some(Scalar::from(1392641_u64)) + chunk + .statistics() + .get_as::(Stat::UncompressedSizeInBytes), + Some(1392641_usize) ); } @@ -167,8 +172,10 @@ mod tests { for chunk in binary_col.chunks() { assert_eq!(chunk.encoding().id(), VarBinEncoding::ID); assert_eq!( - chunk.statistics().get(Stat::UncompressedSizeInBytes), - Some(Scalar::from(134357007_u64)) + chunk + .statistics() + .get_as::(Stat::UncompressedSizeInBytes), + Some(134357007_usize) ); } @@ -180,8 +187,10 @@ mod tests { for chunk in timestamp_col.chunks() { assert_eq!(chunk.encoding().id(), DateTimePartsEncoding::ID); assert_eq!( - chunk.statistics().get(Stat::UncompressedSizeInBytes), - Some((chunk.len() * 8 + 4).into()) + chunk + .statistics() + .get_as::(Stat::UncompressedSizeInBytes), + Some(chunk.len() * 8 + 4) ) } } diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/binary.rs index da532911b3..f422a4e4ae 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/binary.rs @@ -4,8 +4,7 @@ use vortex_buffer::ByteBuffer; use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; -use crate::value::{InnerScalarValue, ScalarValue}; -use crate::Scalar; +use crate::{InnerScalarValue, Scalar, ScalarValue}; #[derive(Debug, Hash)] pub struct BinaryScalar<'a> { diff --git a/vortex-scalar/src/bool.rs b/vortex-scalar/src/bool.rs index 0e9d76dfdb..f3182e3fe0 100644 --- a/vortex-scalar/src/bool.rs +++ b/vortex-scalar/src/bool.rs @@ -4,8 +4,7 @@ use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; -use crate::value::ScalarValue; -use crate::{InnerScalarValue, Scalar}; +use crate::{InnerScalarValue, Scalar, ScalarValue}; #[derive(Debug, Hash)] pub struct BoolScalar<'a> { diff --git a/vortex-scalar/src/extension.rs b/vortex-scalar/src/extension.rs index 0b933cb618..6b24fc66c7 100644 --- a/vortex-scalar/src/extension.rs +++ b/vortex-scalar/src/extension.rs @@ -4,8 +4,7 @@ use std::sync::Arc; use vortex_dtype::{DType, ExtDType}; use vortex_error::{vortex_bail, VortexError, VortexResult}; -use crate::value::ScalarValue; -use crate::Scalar; +use crate::{Scalar, ScalarValue}; pub struct ExtScalar<'a> { ext_dtype: &'a ExtDType, diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 69ea785ae3..bb0e47af94 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -19,11 +19,11 @@ mod null; mod primitive; mod pvalue; mod scalar_type; +mod scalarvalue; #[cfg(feature = "serde")] mod serde; mod struct_; mod utf8; -mod value; pub use binary::*; pub use bool::*; @@ -31,9 +31,9 @@ pub use extension::*; pub use list::*; pub use primitive::*; pub use pvalue::*; +pub use scalarvalue::*; pub use struct_::*; pub use utf8::*; -pub use value::*; use vortex_error::{vortex_bail, VortexExpect, VortexResult}; /// A single logical item, composed of both a [`ScalarValue`] and a logical [`DType`]. diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index 2fe220253b..1d437df8eb 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -8,8 +8,7 @@ use vortex_error::{ vortex_bail, vortex_err, vortex_panic, VortexError, VortexExpect as _, VortexResult, }; -use crate::value::{InnerScalarValue, ScalarValue}; -use crate::Scalar; +use crate::{InnerScalarValue, Scalar, ScalarValue}; pub struct ListScalar<'a> { dtype: &'a DType, diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index 158f8a8c97..d0fe2b1098 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -1,6 +1,7 @@ use std::any::type_name; use std::cmp::Ordering; use std::fmt::{Debug, Display}; +use std::ops::Sub; use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, FromPrimitive}; use vortex_dtype::half::f16; @@ -11,8 +12,7 @@ use vortex_error::{ }; use crate::pvalue::PValue; -use crate::value::ScalarValue; -use crate::{InnerScalarValue, Scalar}; +use crate::{InnerScalarValue, Scalar, ScalarValue}; #[derive(Debug, Clone, Copy, Hash)] pub struct PrimitiveScalar<'a> { @@ -188,7 +188,7 @@ impl<'a> TryFrom<&'a Scalar> for PrimitiveScalar<'a> { } } -impl std::ops::Sub for PrimitiveScalar<'_> { +impl Sub for PrimitiveScalar<'_> { type Output = VortexResult; fn sub(self, rhs: Self) -> Self::Output { @@ -236,13 +236,6 @@ impl Scalar { .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)), ) } - - pub fn zero>(nullability: Nullability) -> Self { - Self { - dtype: DType::Primitive(T::PTYPE, nullability), - value: ScalarValue(InnerScalarValue::Primitive(T::zero().into())), - } - } } macro_rules! primitive_scalar { @@ -453,8 +446,7 @@ mod tests { use vortex_dtype::{DType, Nullability, PType}; use vortex_error::VortexError; - use crate::value::InnerScalarValue; - use crate::{PValue, PrimitiveScalar, ScalarValue}; + use crate::{InnerScalarValue, PValue, PrimitiveScalar, ScalarValue}; #[test] fn test_integer_subtract() { diff --git a/vortex-scalar/src/pvalue.rs b/vortex-scalar/src/pvalue.rs index da9de924aa..705a4049e0 100644 --- a/vortex-scalar/src/pvalue.rs +++ b/vortex-scalar/src/pvalue.rs @@ -43,6 +43,8 @@ impl PartialEq for PValue { } } +impl Eq for PValue {} + impl PartialOrd for PValue { fn partial_cmp(&self, other: &Self) -> Option { match (self, other) { diff --git a/vortex-scalar/src/scalarvalue/binary.rs b/vortex-scalar/src/scalarvalue/binary.rs new file mode 100644 index 0000000000..eccfcd2d90 --- /dev/null +++ b/vortex-scalar/src/scalarvalue/binary.rs @@ -0,0 +1,37 @@ +use std::sync::Arc; + +use vortex_buffer::ByteBuffer; +use vortex_error::{VortexError, VortexExpect, VortexResult}; + +use crate::scalarvalue::InnerScalarValue; +use crate::ScalarValue; + +impl<'a> TryFrom<&'a ScalarValue> for ByteBuffer { + type Error = VortexError; + + fn try_from(scalar: &'a ScalarValue) -> VortexResult { + Ok(scalar + .as_buffer()? + .vortex_expect("Can't convert null scalar into a byte buffer")) + } +} + +impl<'a> TryFrom<&'a ScalarValue> for Option { + type Error = VortexError; + + fn try_from(scalar: &'a ScalarValue) -> VortexResult { + scalar.as_buffer() + } +} + +impl From<&[u8]> for ScalarValue { + fn from(value: &[u8]) -> Self { + ScalarValue::from(ByteBuffer::from(value.to_vec())) + } +} + +impl From for ScalarValue { + fn from(value: ByteBuffer) -> Self { + ScalarValue(InnerScalarValue::Buffer(Arc::new(value))) + } +} diff --git a/vortex-scalar/src/scalarvalue/bool.rs b/vortex-scalar/src/scalarvalue/bool.rs new file mode 100644 index 0000000000..d9ed805a0e --- /dev/null +++ b/vortex-scalar/src/scalarvalue/bool.rs @@ -0,0 +1,20 @@ +use vortex_error::{vortex_err, VortexError, VortexResult}; + +use crate::ScalarValue; + +impl TryFrom<&ScalarValue> for bool { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> VortexResult { + >::try_from(value)? + .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + } +} + +impl TryFrom<&ScalarValue> for Option { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> VortexResult { + value.as_bool() + } +} diff --git a/vortex-scalar/src/scalarvalue/list.rs b/vortex-scalar/src/scalarvalue/list.rs new file mode 100644 index 0000000000..f8e4077c34 --- /dev/null +++ b/vortex-scalar/src/scalarvalue/list.rs @@ -0,0 +1,53 @@ +use std::sync::Arc; + +use vortex_buffer::{BufferString, ByteBuffer}; +use vortex_dtype::half::f16; +use vortex_error::{vortex_err, VortexError}; + +use crate::scalarvalue::InnerScalarValue; +use crate::ScalarValue; + +impl<'a, T: for<'b> TryFrom<&'b ScalarValue, Error = VortexError>> TryFrom<&'a ScalarValue> + for Vec +{ + type Error = VortexError; + + fn try_from(value: &'a ScalarValue) -> Result { + let value = value + .as_list()? + .ok_or_else(|| vortex_err!("Can't convert non list scalar to vec"))?; + + value.iter().map(|v| T::try_from(v)).collect() + } +} + +macro_rules! from_vec_for_scalar_value { + ($T:ty) => { + impl From> for ScalarValue { + fn from(value: Vec<$T>) -> Self { + ScalarValue(InnerScalarValue::List( + value + .into_iter() + .map(ScalarValue::from) + .collect::>(), + )) + } + } + }; +} + +// no From> because it could either be a List or a Buffer +from_vec_for_scalar_value!(u16); +from_vec_for_scalar_value!(u32); +from_vec_for_scalar_value!(u64); +from_vec_for_scalar_value!(usize); // For usize only, we implicitly cast for better ergonomics. +from_vec_for_scalar_value!(i8); +from_vec_for_scalar_value!(i16); +from_vec_for_scalar_value!(i32); +from_vec_for_scalar_value!(i64); +from_vec_for_scalar_value!(f16); +from_vec_for_scalar_value!(f32); +from_vec_for_scalar_value!(f64); +from_vec_for_scalar_value!(String); +from_vec_for_scalar_value!(BufferString); +from_vec_for_scalar_value!(ByteBuffer); diff --git a/vortex-scalar/src/value.rs b/vortex-scalar/src/scalarvalue/mod.rs similarity index 96% rename from vortex-scalar/src/value.rs rename to vortex-scalar/src/scalarvalue/mod.rs index 718f79593b..94c505e02e 100644 --- a/vortex-scalar/src/value.rs +++ b/vortex-scalar/src/scalarvalue/mod.rs @@ -1,3 +1,9 @@ +mod binary; +mod bool; +mod list; +mod primitive; +mod utf8; + use std::fmt::{Display, Write}; use std::sync::Arc; @@ -7,6 +13,7 @@ use vortex_dtype::DType; use vortex_error::{vortex_err, VortexResult}; use crate::pvalue::PValue; +use crate::ScalarType; /// Represents the internal data of a scalar value. Must be interpreted by wrapping /// up with a DType to make a Scalar. @@ -101,7 +108,7 @@ impl Display for InnerScalarValue { } impl ScalarValue { - pub(crate) fn is_null(&self) -> bool { + pub fn is_null(&self) -> bool { self.0.is_null() } @@ -215,6 +222,18 @@ impl InnerScalarValue { } } +impl From> for ScalarValue +where + T: ScalarType, + ScalarValue: From, +{ + fn from(value: Option) -> Self { + value + .map(ScalarValue::from) + .unwrap_or_else(|| ScalarValue(InnerScalarValue::Null)) + } +} + #[cfg(test)] mod test { use std::sync::Arc; diff --git a/vortex-scalar/src/scalarvalue/primitive.rs b/vortex-scalar/src/scalarvalue/primitive.rs new file mode 100644 index 0000000000..5aef2bd0bd --- /dev/null +++ b/vortex-scalar/src/scalarvalue/primitive.rs @@ -0,0 +1,67 @@ +use paste::paste; +use vortex_dtype::half::f16; +use vortex_error::{vortex_err, VortexError}; + +use crate::scalarvalue::InnerScalarValue; +use crate::ScalarValue; + +macro_rules! primitive_scalar { + ($T:ty) => { + impl TryFrom<&ScalarValue> for $T { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> Result { + >::try_from(value)? + .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + } + } + + impl TryFrom<&ScalarValue> for Option<$T> { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> Result { + paste! { + Ok(value.as_pvalue()?.and_then(|v| v.[]())) + } + } + } + + impl From<$T> for ScalarValue { + fn from(value: $T) -> Self { + ScalarValue(InnerScalarValue::Primitive(value.into())) + } + } + }; +} + +primitive_scalar!(u8); +primitive_scalar!(u16); +primitive_scalar!(u32); +primitive_scalar!(u64); +primitive_scalar!(i8); +primitive_scalar!(i16); +primitive_scalar!(i32); +primitive_scalar!(i64); +primitive_scalar!(f16); +primitive_scalar!(f32); +primitive_scalar!(f64); + +/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics. +impl TryFrom<&ScalarValue> for usize { + type Error = VortexError; + + fn try_from(value: &ScalarValue) -> Result { + let prim = value + .as_pvalue()? + .and_then(|v| v.as_u64()) + .ok_or_else(|| vortex_err!("cannot convert Null to usize"))?; + Ok(usize::try_from(prim)?) + } +} + +/// Read a scalar as usize. For usize only, we implicitly cast for better ergonomics. +impl From for ScalarValue { + fn from(value: usize) -> Self { + ScalarValue(InnerScalarValue::Primitive((value as u64).into())) + } +} diff --git a/vortex-scalar/src/scalarvalue/utf8.rs b/vortex-scalar/src/scalarvalue/utf8.rs new file mode 100644 index 0000000000..24126adbd3 --- /dev/null +++ b/vortex-scalar/src/scalarvalue/utf8.rs @@ -0,0 +1,55 @@ +use std::sync::Arc; + +use vortex_buffer::BufferString; +use vortex_error::{vortex_err, VortexError, VortexExpect, VortexResult}; + +use crate::scalarvalue::InnerScalarValue; +use crate::ScalarValue; + +impl<'a> TryFrom<&'a ScalarValue> for String { + type Error = VortexError; + + fn try_from(value: &'a ScalarValue) -> Result { + Ok(value + .as_buffer_string()? + .vortex_expect("Can't convert null ScalarValue to String") + .to_string()) + } +} + +impl From<&str> for ScalarValue { + fn from(value: &str) -> Self { + ScalarValue(InnerScalarValue::BufferString(Arc::new( + value.to_string().into(), + ))) + } +} + +impl From for ScalarValue { + fn from(value: String) -> Self { + ScalarValue(InnerScalarValue::BufferString(Arc::new(value.into()))) + } +} + +impl From for ScalarValue { + fn from(value: BufferString) -> Self { + ScalarValue(InnerScalarValue::BufferString(Arc::new(value))) + } +} + +impl<'a> TryFrom<&'a ScalarValue> for BufferString { + type Error = VortexError; + + fn try_from(scalar: &'a ScalarValue) -> VortexResult { + >::try_from(scalar)? + .ok_or_else(|| vortex_err!("Can't extract present value from null scalar")) + } +} + +impl<'a> TryFrom<&'a ScalarValue> for Option { + type Error = VortexError; + + fn try_from(scalar: &'a ScalarValue) -> Result { + scalar.as_buffer_string() + } +} diff --git a/vortex-scalar/src/serde/serde.rs b/vortex-scalar/src/serde/serde.rs index 2f24ff3b2a..7802686385 100644 --- a/vortex-scalar/src/serde/serde.rs +++ b/vortex-scalar/src/serde/serde.rs @@ -7,7 +7,7 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; use vortex_buffer::{BufferString, ByteBuffer}; use crate::pvalue::PValue; -use crate::value::{InnerScalarValue, ScalarValue}; +use crate::{InnerScalarValue, ScalarValue}; impl Serialize for ScalarValue { fn serialize(&self, serializer: S) -> Result diff --git a/vortex-scalar/src/struct_.rs b/vortex-scalar/src/struct_.rs index 7788be30b9..c5fc418ed7 100644 --- a/vortex-scalar/src/struct_.rs +++ b/vortex-scalar/src/struct_.rs @@ -8,8 +8,7 @@ use vortex_error::{ vortex_bail, vortex_err, vortex_panic, VortexError, VortexExpect, VortexResult, }; -use crate::value::ScalarValue; -use crate::{InnerScalarValue, Scalar}; +use crate::{InnerScalarValue, Scalar, ScalarValue}; pub struct StructScalar<'a> { dtype: &'a DType, diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/utf8.rs index f165883407..6c718c843c 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/utf8.rs @@ -5,8 +5,7 @@ use vortex_dtype::Nullability::NonNullable; use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_err, VortexError, VortexExpect as _, VortexResult}; -use crate::value::ScalarValue; -use crate::{InnerScalarValue, Scalar}; +use crate::{InnerScalarValue, Scalar, ScalarValue}; #[derive(Debug, Hash)] pub struct Utf8Scalar<'a> { diff --git a/vortex-scan/src/lib.rs b/vortex-scan/src/lib.rs index 896f6d49e4..758a3e18d3 100644 --- a/vortex-scan/src/lib.rs +++ b/vortex-scan/src/lib.rs @@ -11,7 +11,6 @@ use std::sync::Arc; pub use range_scan::*; pub use row_mask::*; -use vortex_array::{ArrayDType, Canonical, IntoArrayData}; use vortex_dtype::DType; use vortex_error::VortexResult; use vortex_expr::forms::cnf::cnf; @@ -48,10 +47,7 @@ impl Scanner { // TODO(ngates): compute and cache a FieldMask based on the referenced fields. // Where FieldMask ~= Vec - let result_dtype = projection - .evaluate(&Canonical::empty(&dtype)?.into_array())? - .dtype() - .clone(); + let result_dtype = projection.return_dtype(&dtype)?; let conjuncts: Box<[ExprRef]> = if let Some(filter) = filter { let conjuncts = cnf(filter)?;