Skip to content

Commit d05dcc3

Browse files
authored
fix: Filter out nulls properly in approx_percentile_cont_with_weight (#17780)
1 parent 320dd91 commit d05dcc3

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

datafusion/functions-aggregate/src/approx_percentile_cont.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -384,19 +384,23 @@ impl ApproxPercentileAccumulator {
384384
}
385385
}
386386

387-
// public for approx_percentile_cont_with_weight
387+
// pub(crate) for approx_percentile_cont_with_weight
388388
pub(crate) fn max_size(&self) -> usize {
389389
self.digest.max_size()
390390
}
391391

392-
// public for approx_percentile_cont_with_weight
393-
pub fn merge_digests(&mut self, digests: &[TDigest]) {
392+
// pub(crate) for approx_percentile_cont_with_weight
393+
pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
394394
let digests = digests.iter().chain(std::iter::once(&self.digest));
395395
self.digest = TDigest::merge_digests(digests)
396396
}
397397

398-
// public for approx_percentile_cont_with_weight
399-
pub fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
398+
// pub(crate) for approx_percentile_cont_with_weight
399+
pub(crate) fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
400+
debug_assert!(
401+
values.null_count() == 0,
402+
"convert_to_float assumes nulls have already been filtered out"
403+
);
400404
match values.data_type() {
401405
DataType::Float64 => {
402406
let array = downcast_value!(values, Float64Array);
@@ -493,7 +497,7 @@ impl Accumulator for ApproxPercentileAccumulator {
493497
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
494498
// Remove any nulls before computing the percentile
495499
let mut values = Arc::clone(&values[0]);
496-
if values.nulls().is_some() {
500+
if values.null_count() > 0 {
497501
values = filter(&values, &is_not_null(&values)?)?;
498502
}
499503
let sorted_values = &arrow::compute::sort(&values, None)?;

datafusion/functions-aggregate/src/approx_percentile_cont_with_weight.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use std::hash::Hash;
2121
use std::mem::size_of_val;
2222
use std::sync::Arc;
2323

24+
use arrow::compute::{and, filter, is_not_null};
2425
use arrow::datatypes::FieldRef;
2526
use arrow::{array::ArrayRef, datatypes::DataType};
2627
use datafusion_common::ScalarValue;
@@ -268,15 +269,37 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
268269
}
269270

270271
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
271-
let means = &values[0];
272-
let weights = &values[1];
272+
let mut means = Arc::clone(&values[0]);
273+
let mut weights = Arc::clone(&values[1]);
274+
// If nulls are present in either array, need to filter those rows out in both arrays
275+
match (means.null_count() > 0, weights.null_count() > 0) {
276+
// Both have nulls
277+
(true, true) => {
278+
let predicate = and(&is_not_null(&means)?, &is_not_null(&weights)?)?;
279+
means = filter(&means, &predicate)?;
280+
weights = filter(&weights, &predicate)?;
281+
}
282+
// Only one has nulls
283+
(false, true) => {
284+
let predicate = &is_not_null(&weights)?;
285+
means = filter(&means, predicate)?;
286+
weights = filter(&weights, predicate)?;
287+
}
288+
(true, false) => {
289+
let predicate = &is_not_null(&means)?;
290+
means = filter(&means, predicate)?;
291+
weights = filter(&weights, predicate)?;
292+
}
293+
// No nulls
294+
(false, false) => {}
295+
}
273296
debug_assert_eq!(
274297
means.len(),
275298
weights.len(),
276299
"invalid number of values in means and weights"
277300
);
278-
let means_f64 = ApproxPercentileAccumulator::convert_to_float(means)?;
279-
let weights_f64 = ApproxPercentileAccumulator::convert_to_float(weights)?;
301+
let means_f64 = ApproxPercentileAccumulator::convert_to_float(&means)?;
302+
let weights_f64 = ApproxPercentileAccumulator::convert_to_float(&weights)?;
280303
let mut digests: Vec<TDigest> = vec![];
281304
for (mean, weight) in means_f64.iter().zip(weights_f64.iter()) {
282305
digests.push(TDigest::new_with_centroid(

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1540,6 +1540,19 @@ SELECT APPROX_PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY v) FROM (VALUES (CAST(
15401540
----
15411541
NULL
15421542

1543+
# percentile_cont_with_weight_with_nulls
1544+
query I
1545+
SELECT APPROX_PERCENTILE_CONT_WITH_WEIGHT(w, 0.5) WITHIN GROUP (ORDER BY v)
1546+
FROM (VALUES (1, 1), (2, 1), (3, 1), (4, NULL), (NULL, 1), (NULL, NULL)) as t (v, w);
1547+
----
1548+
2
1549+
1550+
# percentile_cont_with_weight_nulls_only
1551+
query I
1552+
SELECT APPROX_PERCENTILE_CONT_WITH_WEIGHT(1, 0.5) WITHIN GROUP (ORDER BY v) FROM (VALUES (CAST(NULL as INT))) as t (v);
1553+
----
1554+
NULL
1555+
15431556
#
15441557
# percentile_cont edge cases
15451558
#

0 commit comments

Comments
 (0)