Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: New quantile interpolation method & QUANTILE_DISC function in SQL #19139

Merged
merged 14 commits into from
Oct 16, 2024
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/legacy/kernels/rolling/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,5 @@ pub struct RollingVarParams {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct RollingQuantileParams {
pub prob: f64,
pub interpol: QuantileInterpolOptions,
pub method: QuantileMethod,
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,19 @@ where

#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum QuantileInterpolOptions {
pub enum QuantileMethod {
#[default]
Nearest,
Lower,
Higher,
Midpoint,
Linear,
Equiprobable,
}

#[deprecated(note = "use QuantileMethod instead")]
pub type QuantileInterpolOptions = QuantileMethod;

pub(super) fn rolling_apply_weights<T, Fo, Fa>(
values: &[T],
window_size: usize,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use num_traits::ToPrimitive;
use polars_error::polars_ensure;
use polars_utils::slice::GetSaferUnchecked;

use super::QuantileInterpolOptions::*;
use super::QuantileMethod::*;
use super::*;

pub struct QuantileWindow<'a, T: NativeType> {
sorted: SortedBuf<'a, T>,
prob: f64,
interpol: QuantileInterpolOptions,
method: QuantileMethod,
}

impl<
Expand All @@ -34,15 +34,15 @@ impl<
Self {
sorted: SortedBuf::new(slice, start, end),
prob: params.prob,
interpol: params.interpol,
method: params.method,
}
}

unsafe fn update(&mut self, start: usize, end: usize) -> Option<T> {
let vals = self.sorted.update(start, end);
let length = vals.len();

let idx = match self.interpol {
let idx = match self.method {
Linear => {
// Maybe add a fast path for median case? They could branch depending on odd/even.
let length_f = length as f64;
Expand Down Expand Up @@ -92,6 +92,7 @@ impl<
let idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
std::cmp::min(idx, length - 1)
},
Equiprobable => ((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize,
};

// SAFETY:
Expand Down Expand Up @@ -134,7 +135,7 @@ where
unreachable!("expected Quantile params");
};
let out = super::quantile_filter::rolling_quantile::<_, Vec<_>>(
params.interpol,
params.method,
min_periods,
window_size,
values,
Expand Down Expand Up @@ -170,7 +171,7 @@ where
Ok(rolling_apply_weighted_quantile(
values,
params.prob,
params.interpol,
params.method,
window_size,
min_periods,
offset_fn,
Expand All @@ -182,7 +183,7 @@ where
}

#[inline]
fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, interp: QuantileInterpolOptions) -> T
fn compute_wq<T>(buf: &[(T, f64)], p: f64, wsum: f64, method: QuantileMethod) -> T
where
T: Debug + NativeType + Mul<Output = T> + Sub<Output = T> + NumCast + ToPrimitive + Zero,
{
Expand All @@ -201,7 +202,7 @@ where
(s_old, v_old, vk) = (s, vk, v);
s += w;
}
match (h == s_old, interp) {
match (h == s_old, method) {
(true, _) => v_old, // If we hit the break exactly interpolation shouldn't matter
(_, Lower) => v_old,
(_, Higher) => vk,
Expand All @@ -212,6 +213,14 @@ where
vk
}
},
(_, Equiprobable) => {
let threshold = (wsum * p).ceil() - 1.0;
if s > threshold {
vk
} else {
v_old
}
},
(_, Midpoint) => (vk + v_old) * NumCast::from(0.5).unwrap(),
// This is seemingly the canonical way to do it.
(_, Linear) => {
Expand All @@ -224,7 +233,7 @@ where
fn rolling_apply_weighted_quantile<T, Fo>(
values: &[T],
p: f64,
interpolation: QuantileInterpolOptions,
method: QuantileMethod,
window_size: usize,
min_periods: usize,
det_offsets_fn: Fo,
Expand Down Expand Up @@ -252,7 +261,7 @@ where
.for_each(|(b, (i, w))| *b = (*values.get_unchecked(i + start), **w));
}
buf.sort_unstable_by(|&a, &b| a.0.tot_cmp(&b.0));
compute_wq(&buf, p, wsum, interpolation)
compute_wq(&buf, p, wsum, method)
})
.collect_trusted::<Vec<T>>();

Expand All @@ -273,7 +282,7 @@ mod test {
let values = &[1.0, 2.0, 3.0, 4.0];
let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.5,
interpol: Linear,
method: Linear,
}));
let out = rolling_quantile(values, 2, 2, false, None, med_pars.clone()).unwrap();
let out = out.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand Down Expand Up @@ -305,18 +314,19 @@ mod test {
fn test_rolling_quantile_limits() {
let values = &[1.0f64, 2.0, 3.0, 4.0];

let interpol_options = vec![
QuantileInterpolOptions::Lower,
QuantileInterpolOptions::Higher,
QuantileInterpolOptions::Nearest,
QuantileInterpolOptions::Midpoint,
QuantileInterpolOptions::Linear,
let methods = vec![
QuantileMethod::Lower,
QuantileMethod::Higher,
QuantileMethod::Nearest,
QuantileMethod::Midpoint,
QuantileMethod::Linear,
QuantileMethod::Equiprobable,
];

for interpol in interpol_options {
for method in methods {
let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.0,
interpol,
method,
}));
let out1 = rolling_min(values, 2, 2, false, None, None).unwrap();
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand All @@ -328,7 +338,7 @@ mod test {

let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 1.0,
interpol,
method,
}));
let out1 = rolling_max(values, 2, 2, false, None, None).unwrap();
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand Down
46 changes: 24 additions & 22 deletions crates/polars-arrow/src/legacy/kernels/rolling/nulls/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::array::MutablePrimitiveArray;
pub struct QuantileWindow<'a, T: NativeType + IsFloat + PartialOrd> {
sorted: SortedBufNulls<'a, T>,
prob: f64,
interpol: QuantileInterpolOptions,
method: QuantileMethod,
}

impl<
Expand Down Expand Up @@ -39,7 +39,7 @@ impl<
Self {
sorted: SortedBufNulls::new(slice, validity, start, end),
prob: params.prob,
interpol: params.interpol,
method: params.method,
}
}

Expand All @@ -53,29 +53,30 @@ impl<
let values = &values[null_count..];
let length = values.len();

let mut idx = match self.interpol {
QuantileInterpolOptions::Nearest => ((length as f64) * self.prob) as usize,
QuantileInterpolOptions::Lower
| QuantileInterpolOptions::Midpoint
| QuantileInterpolOptions::Linear => {
let mut idx = match self.method {
QuantileMethod::Nearest => ((length as f64) * self.prob) as usize,
QuantileMethod::Lower | QuantileMethod::Midpoint | QuantileMethod::Linear => {
((length as f64 - 1.0) * self.prob).floor() as usize
},
QuantileInterpolOptions::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
QuantileMethod::Higher => ((length as f64 - 1.0) * self.prob).ceil() as usize,
QuantileMethod::Equiprobable => {
((length as f64 * self.prob).ceil() - 1.0).max(0.0) as usize
},
};

idx = std::cmp::min(idx, length - 1);

// we can unwrap because we sliced of the nulls
match self.interpol {
QuantileInterpolOptions::Midpoint => {
match self.method {
QuantileMethod::Midpoint => {
let top_idx = ((length as f64 - 1.0) * self.prob).ceil() as usize;
Some(
(values.get_unchecked_release(idx).unwrap()
+ values.get_unchecked_release(top_idx).unwrap())
/ T::from::<f64>(2.0f64).unwrap(),
)
},
QuantileInterpolOptions::Linear => {
QuantileMethod::Linear => {
let float_idx = (length as f64 - 1.0) * self.prob;
let top_idx = f64::ceil(float_idx) as usize;

Expand Down Expand Up @@ -136,7 +137,7 @@ where
};

let out = super::quantile_filter::rolling_quantile::<_, MutablePrimitiveArray<_>>(
params.interpol,
params.method,
min_periods,
window_size,
arr.clone(),
Expand Down Expand Up @@ -171,7 +172,7 @@ mod test {
);
let med_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.5,
interpol: QuantileInterpolOptions::Linear,
method: QuantileMethod::Linear,
}));

let out = rolling_quantile(arr, 2, 2, false, None, med_pars.clone());
Expand Down Expand Up @@ -210,18 +211,19 @@ mod test {
Some(Bitmap::from(&[true, false, false, true, true])),
);

let interpol_options = vec![
QuantileInterpolOptions::Lower,
QuantileInterpolOptions::Higher,
QuantileInterpolOptions::Nearest,
QuantileInterpolOptions::Midpoint,
QuantileInterpolOptions::Linear,
let methods = vec![
QuantileMethod::Lower,
QuantileMethod::Higher,
QuantileMethod::Nearest,
QuantileMethod::Midpoint,
QuantileMethod::Linear,
QuantileMethod::Equiprobable,
];

for interpol in interpol_options {
for method in methods {
let min_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 0.0,
interpol,
method,
}));
let out1 = rolling_min(values, 2, 1, false, None, None);
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand All @@ -233,7 +235,7 @@ mod test {

let max_pars = Some(RollingFnParams::Quantile(RollingQuantileParams {
prob: 1.0,
interpol,
method,
}));
let out1 = rolling_max(values, 2, 1, false, None, None);
let out1 = out1.as_any().downcast_ref::<PrimitiveArray<f64>>().unwrap();
Expand Down
Loading