Skip to content

Commit 206af12

Browse files
committed
Reduce record batch filtering
1 parent 3f0eaa3 commit 206af12

File tree

1 file changed

+75
-46
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+75
-46
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 75 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,25 @@
1717

1818
use crate::expressions::try_cast;
1919
use crate::PhysicalExpr;
20-
use std::borrow::Cow;
21-
use std::hash::Hash;
22-
use std::{any::Any, sync::Arc};
23-
2420
use arrow::array::*;
2521
use arrow::compute::kernels::zip::zip;
26-
use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter};
22+
use arrow::compute::{
23+
and, and_not, filter_record_batch, is_null, not, nullif, or, prep_null_mask_filter,
24+
};
2725
use arrow::datatypes::{DataType, Schema};
26+
use arrow::error::ArrowError;
2827
use datafusion_common::cast::as_boolean_array;
2928
use datafusion_common::{
3029
exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue,
3130
};
3231
use datafusion_expr::ColumnarValue;
32+
use std::borrow::Cow;
33+
use std::hash::Hash;
34+
use std::{any::Any, sync::Arc};
3335

3436
use super::{Column, Literal};
3537
use datafusion_physical_expr_common::datum::compare_with_eq;
38+
use datafusion_physical_expr_common::utils::scatter;
3639
use itertools::Itertools;
3740

3841
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
@@ -122,6 +125,24 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool {
122125
expr.as_any().is::<Column>()
123126
}
124127

128+
/// Merges the values produced by one branch into `current_value` and returns the
/// resulting array.
///
/// `then_scatter` is the boolean mask passed to the `nullif`, `zip` and `scatter`
/// kernels below. How `then_value` is merged depends on its variant:
///
/// * `ColumnarValue::Scalar(ScalarValue::Null)`: the result is
///   `nullif(current_value, then_scatter)`.
/// * any other `ColumnarValue::Scalar`: the scalar and `current_value` are combined
///   with `zip`, using `then_scatter` as the mask.
/// * `ColumnarValue::Array`: the array is first expanded with
///   `scatter(then_scatter, ..)` and the expanded array is then combined with
///   `current_value` using `zip` and the same mask.
///
/// Any failure from these kernels, or from converting the scalar with `to_scalar`,
/// is returned as an `ArrowError`.
fn merge_result(
129+
current_value: &dyn Array,
130+
then_value: ColumnarValue,
131+
then_scatter: &BooleanArray,
132+
) -> std::result::Result<ArrayRef, ArrowError> {
133+
match then_value {
134+
ColumnarValue::Scalar(ScalarValue::Null) => nullif(current_value, then_scatter),
135+
ColumnarValue::Scalar(then_value) => {
136+
zip(then_scatter, &then_value.to_scalar()?, &current_value)
137+
}
138+
ColumnarValue::Array(then_value) => {
139+
// TODO this operation should probably be feasible in one pass
140+
// NOTE(review): `then_value` presumably holds one entry per `true` slot of
// `then_scatter` (the caller evaluates it on a filtered batch), so it is
// expanded to the mask's length before `zip` - confirm against `scatter`.
let scattered_then = scatter(then_scatter, then_value.as_ref())?;
141+
zip(then_scatter, &scattered_then, &current_value)
142+
}
143+
}
144+
}
145+
125146
impl CaseExpr {
126147
/// Create a new CASE WHEN expression
127148
pub fn try_new(
@@ -286,64 +307,72 @@ impl CaseExpr {
286307

287308
// start with nulls as default output
288309
let mut current_value = new_null_array(&return_type, batch.num_rows());
289-
let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
290-
let mut remainder_count = batch.num_rows();
291-
for i in 0..self.when_then_expr.len() {
292-
// If there are no rows left to process, break out of the loop early
293-
if remainder_count == 0 {
294-
break;
295-
}
296310

311+
let mut remainder_scatter = BooleanArray::from(vec![true; batch.num_rows()]);
312+
let mut remainder_batch = Cow::Borrowed(batch);
313+
314+
for i in 0..self.when_then_expr.len() {
315+
// Evaluate the 'when' predicate for the remainder batch
316+
// This results in a boolean array with the same length as the remaining number of rows
297317
let when_predicate = &self.when_then_expr[i].0;
298-
let when_value = when_predicate.evaluate_selection(batch, &remainder)?;
299-
let when_value = when_value.into_array(batch.num_rows())?;
318+
let when_value = when_predicate.evaluate(&remainder_batch)?;
319+
let when_value = when_value.to_array(remainder_batch.num_rows())?;
300320
let when_value = as_boolean_array(&when_value).map_err(|_| {
301321
internal_datafusion_err!("WHEN expression did not return a BooleanArray")
302322
})?;
303-
// Treat 'NULL' as false value
304-
let when_value = match when_value.null_count() {
305-
0 => Cow::Borrowed(when_value),
306-
_ => Cow::Owned(prep_null_mask_filter(when_value)),
307-
};
308-
// Make sure we only consider rows that have not been matched yet
309-
let when_value = and(&when_value, &remainder)?;
310323

311-
// If the predicate did not match any rows, continue to the next branch immediately
324+
// If the 'when' predicate did not match any rows, continue to the next branch immediately
312325
let when_match_count = when_value.true_count();
313326
if when_match_count == 0 {
314327
continue;
315328
}
316329

330+
// Make sure 'NULL' is treated as false
331+
let when_value = match when_value.null_count() {
332+
0 => Cow::Borrowed(when_value),
333+
_ => Cow::Owned(prep_null_mask_filter(when_value)),
334+
};
335+
336+
// Filter the remainder batch based on the 'when' value
337+
// This results in a batch containing only the rows that need to be evaluated
338+
// for the current branch
339+
let then_batch = filter_record_batch(&remainder_batch, &when_value)?;
340+
341+
// Evaluate the then expression for the matching rows
317342
let then_expression = &self.when_then_expr[i].1;
318-
let then_value = then_expression.evaluate_selection(batch, &when_value)?;
343+
let then_value = then_expression.evaluate(&then_batch)?;
319344

320-
current_value = match then_value {
321-
ColumnarValue::Scalar(ScalarValue::Null) => {
322-
nullif(current_value.as_ref(), &when_value)?
323-
}
324-
ColumnarValue::Scalar(then_value) => {
325-
zip(&when_value, &then_value.to_scalar()?, &current_value)?
326-
}
327-
ColumnarValue::Array(then_value) => {
328-
zip(&when_value, &then_value, &current_value)?
329-
}
330-
};
345+
// Expand the 'when' match array using the 'remainder scatter' array
346+
// This results in a truthy boolean array that we can use to merge the
347+
// 'then' values with the `current_value` array.
348+
let then_merge = scatter(&remainder_scatter, when_value.as_ref())?;
349+
let then_merge = then_merge.as_boolean();
331350

332-
// Succeed tuples should be filtered out for short-circuit evaluation,
333-
// null values for the current when expr should be kept
334-
remainder = and_not(&remainder, &when_value)?;
335-
remainder_count -= when_match_count;
351+
// Merge the 'then' values with the `current_value` array
352+
current_value = merge_result(&current_value, then_value, then_merge)?;
353+
354+
// If the 'when' predicate matched all remaining rows, there's nothing left to do so
355+
// we can return early
356+
if remainder_batch.num_rows() == when_match_count {
357+
return Ok(ColumnarValue::Array(current_value));
358+
}
359+
360+
// Clear the positions in 'remainder scatter' for which we just evaluated a value
361+
remainder_scatter = and_not(&remainder_scatter, then_merge)?;
362+
363+
// Finally, prepare the remainder batch for the next branch
364+
let next_selection = not(&when_value)?;
365+
remainder_batch =
366+
Cow::Owned(filter_record_batch(&remainder_batch, &next_selection)?);
336367
}
337368

369+
// If we reached this point, some rows were left unmatched.
370+
// Check if those need to be evaluated using the 'else' expression.
338371
if let Some(e) = self.else_expr() {
339-
if remainder_count > 0 {
340-
// keep `else_expr`'s data type and return type consistent
341-
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
342-
let else_ = expr
343-
.evaluate_selection(batch, &remainder)?
344-
.into_array(batch.num_rows())?;
345-
current_value = zip(&remainder, &else_, &current_value)?;
346-
}
372+
// keep `else_expr`'s data type and return type consistent
373+
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
374+
let else_value = expr.evaluate(&remainder_batch)?;
375+
current_value = merge_result(&current_value, else_value, &remainder_scatter)?;
347376
}
348377

349378
Ok(ColumnarValue::Array(current_value))

0 commit comments

Comments
 (0)