Skip to content

Commit

Permalink
This is a combination of 5 commits.
Browse files Browse the repository at this point in the history
ScalarValues
  • Loading branch information
robert3005 committed Jan 24, 2025
1 parent 858ff9b commit ab2e76c
Show file tree
Hide file tree
Showing 48 changed files with 483 additions and 258 deletions.
3 changes: 1 addition & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions encodings/datetime-parts/src/stats.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use vortex_array::stats::{Stat, StatisticsVTable, StatsSet};
use vortex_array::ArrayLen;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::{DateTimePartsArray, DateTimePartsEncoding};

impl StatisticsVTable<DateTimePartsArray> for DateTimePartsEncoding {
fn compute_statistics(&self, array: &DateTimePartsArray, stat: Stat) -> VortexResult<StatsSet> {
let maybe_stat = match stat {
Stat::NullCount => Some(Scalar::from(array.validity().null_count(array.len())?)),
Stat::NullCount => Some(ScalarValue::from(array.validity().null_count(array.len())?)),
_ => None,
};

Expand Down
7 changes: 4 additions & 3 deletions encodings/fastlanes/src/for/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@ pub fn for_compress(array: PrimitiveArray) -> VortexResult<FoRArray> {
.compute(Stat::Min)
.ok_or_else(|| vortex_err!("Min stat not found"))?;

let nullability = array.dtype().nullability();
let dtype = array.dtype().clone();
let nullability = dtype.nullability();
let encoded = match_each_integer_ptype!(array.ptype(), |$T| {
if shift == <$T>::PTYPE.bit_width() as u8 {
assert_eq!(min, Scalar::zero::<$T>(array.dtype().nullability()));
assert_eq!(min, Scalar::zero::<$T>(array.dtype().nullability()).into_value());
encoded_zero::<$T>(array.validity().to_logical(array.len()), nullability)
.vortex_expect("Failed to encode all zeroes")
} else {
Expand All @@ -33,7 +34,7 @@ pub fn for_compress(array: PrimitiveArray) -> VortexResult<FoRArray> {
.into_array()
}
});
FoRArray::try_new(encoded, min, shift)
FoRArray::try_new(encoded, Scalar::new(dtype, min), shift)
}

fn encoded_zero<T: NativePType>(
Expand Down
8 changes: 4 additions & 4 deletions encodings/runend/src/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType as _, ArrayLen as _, IntoArrayVariant as _};
use vortex_dtype::{match_each_unsigned_integer_ptype, DType, NativePType};
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::{RunEndArray, RunEndEncoding};

impl StatisticsVTable<RunEndArray> for RunEndEncoding {
fn compute_statistics(&self, array: &RunEndArray, stat: Stat) -> VortexResult<StatsSet> {
let maybe_stat = match stat {
Stat::Min | Stat::Max => array.values().statistics().compute(stat),
Stat::IsSorted => Some(Scalar::from(
Stat::IsSorted => Some(ScalarValue::from(
array
.values()
.statistics()
Expand All @@ -25,10 +25,10 @@ impl StatisticsVTable<RunEndArray> for RunEndEncoding {
&& array.logical_validity().all_valid(),
)),
Stat::TrueCount => match array.dtype() {
DType::Bool(_) => Some(Scalar::from(array.true_count()?)),
DType::Bool(_) => Some(ScalarValue::from(array.true_count()?)),
_ => None,
},
Stat::NullCount => Some(Scalar::from(array.null_count()?)),
Stat::NullCount => Some(ScalarValue::from(array.null_count()?)),
_ => None,
};

Expand Down
13 changes: 5 additions & 8 deletions encodings/zigzag/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use vortex_array::{
};
use vortex_dtype::{DType, PType};
use vortex_error::{vortex_bail, vortex_err, vortex_panic, VortexExpect as _, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;
use zigzag::ZigZag as ExternalZigZag;

use crate::compress::zigzag_encode;
Expand Down Expand Up @@ -98,15 +98,12 @@ impl StatisticsVTable<ZigZagArray> for ZigZagEncoding {
stats.set(stat, val);
}
} else if matches!(stat, Stat::Min | Stat::Max) {
let encoded_max = array
.encoded()
.statistics()
.compute_as_cast::<u64>(Stat::Max);
let encoded_max = array.encoded().statistics().compute_as::<u64>(Stat::Max);
if let Some(val) = encoded_max {
// The max of the encoded array corresponds to the element with the largest absolute value, so it decodes to either the array's min (if negative) or its max (if non-negative).
let decoded = <i64 as ExternalZigZag>::decode(val);
let decoded_stat = if decoded < 0 { Stat::Min } else { Stat::Max };
stats.set(decoded_stat, Scalar::from(decoded).cast(array.dtype())?);
stats.set(decoded_stat, ScalarValue::from(decoded));
}
}

Expand Down Expand Up @@ -140,8 +137,8 @@ mod test {

let sliced = ZigZagArray::try_from(slice(zigzag, 0, 2).unwrap()).unwrap();
assert_eq!(
scalar_at(&sliced, sliced.len() - 1).unwrap(),
Scalar::from(-5i32)
scalar_at(&sliced, sliced.len() - 1).unwrap().into_value(),
ScalarValue::from(-5i32)
);
for stat in [Stat::Min, Stat::NullCount, Stat::IsConstant] {
let value = sliced.statistics().compute(stat);
Expand Down
6 changes: 3 additions & 3 deletions vortex-array/src/array/bool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use arrow_array::BooleanArray;
use arrow_buffer::{BooleanBufferBuilder, MutableBuffer};
use vortex_buffer::{Alignment, ByteBuffer};
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, VortexError, VortexExpect as _, VortexResult};
use vortex_error::{vortex_bail, VortexExpect as _, VortexResult};

use crate::encoding::ids;
use crate::stats::StatsSet;
Expand All @@ -13,8 +13,8 @@ use crate::validity::{LogicalValidity, Validity, ValidityMetadata, ValidityVTabl
use crate::variants::{BoolArrayTrait, VariantsVTable};
use crate::visitor::{ArrayVisitor, VisitorVTable};
use crate::{
impl_encoding, ArrayData, ArrayLen, Canonical, DeserializeMetadata, IntoArrayData,
IntoCanonical, RkyvMetadata,
impl_encoding, ArrayLen, Canonical, DeserializeMetadata, IntoArrayData, IntoCanonical,
RkyvMetadata,
};

pub mod compute;
Expand Down
5 changes: 1 addition & 4 deletions vortex-array/src/array/chunked/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
use std::fmt::{Debug, Display};

use futures_util::stream;
use rkyv::{access, to_bytes};
use serde::{Deserialize, Serialize};
use vortex_buffer::BufferMut;
use vortex_dtype::{DType, Nullability, PType};
use vortex_error::{
vortex_bail, vortex_panic, VortexError, VortexExpect as _, VortexResult, VortexUnwrap,
};
use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult, VortexUnwrap};

use crate::array::primitive::PrimitiveArray;
use crate::compute::{scalar_at, search_sorted_usize, SearchSortedSide};
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/array/constant/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ mod tests {
let canonical = const_array.into_canonical().unwrap();
let canonical_stats = canonical.statistics().to_set();

assert_eq!(canonical_stats, StatsSet::constant(&scalar, 4));
assert_eq!(canonical_stats, StatsSet::constant(scalar, 4));
assert_eq!(canonical_stats, stats);
}
}
9 changes: 4 additions & 5 deletions vortex-array/src/array/constant/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::fmt::Display;
use std::num::IntErrorKind::Empty;

use serde::{Deserialize, Serialize};
use vortex_error::{VortexExpect, VortexResult};
Expand All @@ -25,7 +24,7 @@ impl ConstantArray {
S: Into<Scalar>,
{
let scalar = scalar.into();
let stats = StatsSet::constant(&scalar, length);
let stats = StatsSet::constant(scalar.clone(), length);
let (dtype, scalar_value) = scalar.into_parts();

// Serialize the scalar_value into a FlatBuffer
Expand All @@ -44,13 +43,13 @@ impl ConstantArray {

/// Returns the [`Scalar`] value of this constant array.
pub fn scalar(&self) -> Scalar {
let value = ScalarValue::from_flexbytes(
let sv = ScalarValue::from_flexbytes(
self.as_ref()
.byte_buffer(0)
.vortex_expect("Missing scalar value buffer"),
)
.vortex_expect("Failed to deserialize scalar value");
Scalar::new(self.dtype().clone(), value)
Scalar::new(self.dtype().clone(), sv)
}
}

Expand All @@ -71,7 +70,7 @@ impl ValidityVTable<ConstantArray> for ConstantEncoding {

impl StatisticsVTable<ConstantArray> for ConstantEncoding {
fn compute_statistics(&self, array: &ConstantArray, _stat: Stat) -> VortexResult<StatsSet> {
Ok(StatsSet::constant(&array.scalar(), array.len()))
Ok(StatsSet::constant(array.scalar(), array.len()))
}
}

Expand Down
30 changes: 6 additions & 24 deletions vortex-array/src/array/extension/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::fmt::{Debug, Display};
use std::sync::Arc;

use enum_iterator::all;
use serde::{Deserialize, Serialize};
use vortex_buffer::ByteBuffer;
use vortex_dtype::{DType, ExtDType, ExtID};
use vortex_error::{VortexExpect as _, VortexResult};

Expand Down Expand Up @@ -34,7 +32,7 @@ impl ExtensionArray {
EmptyMetadata,
None,
Some([storage].into()),
Default::default(),
StatsSet::default(),
)
.vortex_expect("Invalid ExtensionArray")
}
Expand Down Expand Up @@ -93,25 +91,15 @@ impl VisitorVTable<ExtensionArray> for ExtensionEncoding {

impl StatisticsVTable<ExtensionArray> for ExtensionEncoding {
fn compute_statistics(&self, array: &ExtensionArray, stat: Stat) -> VortexResult<StatsSet> {
let mut stats = array.storage().statistics().compute_all(&[stat])?;

// For stats such as min/max, we want to cast to the extension array's dtype;
// for other stats, we don't need to change anything.
for stat in all::<Stat>().filter(|s| s.has_same_dtype_as_array()) {
if let Some(value) = stats.get(stat) {
stats.set(stat, value.cast(array.dtype())?);
}
}

Ok(stats)
array.storage().statistics().compute_all(&[stat])
}
}

#[cfg(test)]
mod tests {
use vortex_buffer::buffer;
use vortex_dtype::PType;
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use super::*;
use crate::IntoArrayData;
Expand All @@ -123,7 +111,7 @@ mod tests {
DType::from(PType::I64).into(),
None,
));
let array = ExtensionArray::new(ext_dtype.clone(), buffer![1i64, 2, 3, 4, 5].into_array());
let array = ExtensionArray::new(ext_dtype, buffer![1i64, 2, 3, 4, 5].into_array());

let stats = array
.statistics()
Expand All @@ -136,14 +124,8 @@ mod tests {
num_stats
);

assert_eq!(
stats.get(Stat::Min),
Some(&Scalar::extension(ext_dtype.clone(), Scalar::from(1_i64)))
);
assert_eq!(
stats.get(Stat::Max),
Some(&Scalar::extension(ext_dtype, Scalar::from(5_i64)))
);
assert_eq!(stats.get(Stat::Min), Some(&ScalarValue::from(1i64)));
assert_eq!(stats.get(Stat::Max), Some(&ScalarValue::from(5_i64)));
assert_eq!(stats.get(Stat::NullCount), Some(&0u64.into()));
}
}
18 changes: 9 additions & 9 deletions vortex-array/src/array/primitive/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use num_traits::PrimInt;
use vortex_dtype::half::f16;
use vortex_dtype::{match_each_native_ptype, DType, NativePType, Nullability};
use vortex_error::{vortex_panic, VortexResult};
use vortex_scalar::Scalar;
use vortex_scalar::ScalarValue;

use crate::array::primitive::PrimitiveArray;
use crate::array::PrimitiveEncoding;
Expand All @@ -18,9 +18,9 @@ use crate::validity::{ArrayValidity, LogicalValidity};
use crate::variants::PrimitiveArrayTrait;
use crate::{ArrayDType, IntoArrayVariant};

trait PStatsType: NativePType + Into<Scalar> + BitWidth {}
trait PStatsType: NativePType + Into<ScalarValue> + BitWidth {}

impl<T: NativePType + Into<Scalar> + BitWidth> PStatsType for T {}
impl<T: NativePType + Into<ScalarValue> + BitWidth> PStatsType for T {}

impl StatisticsVTable<PrimitiveArray> for PrimitiveEncoding {
fn compute_statistics(&self, array: &PrimitiveArray, stat: Stat) -> VortexResult<StatsSet> {
Expand All @@ -43,10 +43,10 @@ impl StatisticsVTable<PrimitiveArray> for PrimitiveEncoding {
})?;

if let Some(min) = stats.get(Stat::Min) {
stats.set(Stat::Min, min.cast(array.dtype())?);
stats.set(Stat::Min, min.clone());
}
if let Some(max) = stats.get(Stat::Max) {
stats.set(Stat::Max, max.cast(array.dtype())?);
stats.set(Stat::Max, max.clone());
}
Ok(stats)
}
Expand Down Expand Up @@ -169,7 +169,7 @@ fn compute_min_max<T: PStatsType>(
match iter.minmax_by(|a, b| a.total_compare(*b)) {
MinMaxResult::NoElements => StatsSet::default(),
MinMaxResult::OneElement(x) => {
let scalar: Scalar = x.into();
let scalar = x.into();
StatsSet::new_unchecked(vec![
(Stat::Min, scalar.clone()),
(Stat::Max, scalar),
Expand Down Expand Up @@ -334,7 +334,7 @@ impl<T: PStatsType> BitWidthAccumulator<T> {

#[cfg(test)]
mod test {
use vortex_scalar::Scalar;
use vortex_scalar::{Scalar, ScalarValue};

use crate::array::primitive::PrimitiveArray;
use crate::stats::{ArrayStatistics, Stat};
Expand Down Expand Up @@ -399,8 +399,8 @@ mod test {
#[test]
fn all_null() {
let arr = PrimitiveArray::from_option_iter([Option::<i32>::None, None, None]);
let min: Option<Scalar> = arr.statistics().compute(Stat::Min);
let max: Option<Scalar> = arr.statistics().compute(Stat::Max);
let min = arr.statistics().compute(Stat::Min);
let max = arr.statistics().compute(Stat::Max);
assert_eq!(min, None);
assert_eq!(max, None);
}
Expand Down
13 changes: 5 additions & 8 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::fmt::{Debug, Display};

use rkyv::rancor::{Error, Failure};
use rkyv::{access, from_bytes, to_bytes, Deserialize};
use vortex_error::{
vortex_bail, vortex_err, vortex_panic, VortexError, VortexExpect as _, VortexResult,
};
use rkyv::Deserialize;
use vortex_error::{vortex_bail, vortex_panic, VortexExpect as _, VortexResult};
use vortex_scalar::{Scalar, ScalarValue};

use crate::array::constant::ConstantArray;
Expand Down Expand Up @@ -150,13 +147,13 @@ impl SparseArray {

#[inline]
pub fn fill_scalar(&self) -> Scalar {
let fill_value = ScalarValue::from_flexbytes(
let sv = ScalarValue::from_flexbytes(
self.as_ref()
.byte_buffer(0)
.vortex_expect("Missing fill value buffer"),
)
.vortex_expect("Failed to deserialize fill value");
Scalar::new(self.dtype().clone(), fill_value)
Scalar::new(self.dtype().clone(), sv)
}
}

Expand All @@ -180,7 +177,7 @@ impl StatisticsVTable<SparseArray> for SparseEncoding {
let fill_stats = if array.fill_scalar().is_null() {
StatsSet::nulls(fill_len, array.dtype())
} else {
StatsSet::constant(&array.fill_scalar(), fill_len)
StatsSet::constant(array.fill_scalar(), fill_len)
};

if values.is_empty() {
Expand Down
Loading

0 comments on commit ab2e76c

Please sign in to comment.