Skip to content

Commit 01d51d8

Browse files
committed
Use a custom interleave strategy that takes the case evaluation logic into account
1 parent 7e21a19 commit 01d51d8

File tree

1 file changed

+121
-65
lines changed
  • datafusion/physical-expr/src/expressions

1 file changed

+121
-65
lines changed

datafusion/physical-expr/src/expressions/case.rs

Lines changed: 121 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ use crate::PhysicalExpr;
2121
use arrow::array::*;
2222
use arrow::compute::kernels::zip::zip;
2323
use arrow::compute::{
24-
interleave, is_null, not, nullif, prep_null_mask_filter, FilterBuilder,
25-
FilterPredicate,
24+
is_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate,
2625
};
2726
use arrow::datatypes::{DataType, Schema, UInt32Type};
2827
use arrow::error::ArrowError;
@@ -40,7 +39,7 @@ use std::{any::Any, sync::Arc};
4039
type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>);
4140

4241
#[derive(Debug, Hash, PartialEq, Eq)]
43-
enum EvalMethod {
42+
pub enum EvalMethod {
4443
/// CASE WHEN condition THEN result
4544
/// [WHEN ...]
4645
/// [ELSE result]
@@ -96,7 +95,7 @@ pub struct CaseExpr {
9695
/// Optional "else" expression
9796
else_expr: Option<Arc<dyn PhysicalExpr>>,
9897
/// Evaluation method to use
99-
eval_method: EvalMethod,
98+
pub eval_method: EvalMethod,
10099
}
101100

102101
impl std::fmt::Display for CaseExpr {
@@ -159,18 +158,27 @@ fn filter_array(
159158
filter.filter(array)
160159
}
161160

162-
struct InterleaveBuilder {
163-
indices: Vec<(usize, usize)>,
164-
arrays: Vec<ArrayRef>,
161+
struct ResultBuilder {
162+
data_type: DataType,
163+
// A Vec of partial results that should be merged. `partial_result_indices` contains
164+
// indexes into this vec.
165+
partial_results: Vec<ArrayData>,
166+
// Indicates per result row from which array in `partial_results` a value should be taken.
167+
// The indexes in this array are offset by +1. The special value 0 indicates null values.
168+
partial_result_indices: Vec<usize>,
169+
// An optional result that is the covering result for all rows.
170+
// This is used as an optimisation to avoid the cost of merging when all rows
171+
// evaluate to the same case branch.
172+
covering_result: Option<ColumnarValue>,
165173
}
166174

167-
impl InterleaveBuilder {
175+
impl ResultBuilder {
168176
fn new(data_type: &DataType, capacity: usize) -> Self {
169-
// By settings indices to (0, 0) every entry points to the single
170-
// null value in the first array.
171177
Self {
172-
indices: vec![(0, 0); capacity],
173-
arrays: vec![new_null_array(data_type, 1)],
178+
data_type: data_type.clone(),
179+
partial_result_indices: vec![0; capacity],
180+
partial_results: vec![],
181+
covering_result: None,
174182
}
175183
}
176184

@@ -183,56 +191,106 @@ impl InterleaveBuilder {
183191
/// If `value` is an array, the values from the array and the indices from `rows` will be
184192
/// processed pairwise.
185193
fn add_result(&mut self, rows: &ArrayRef, value: ColumnarValue) -> Result<()> {
186-
let array_index = self.arrays.len();
187194
match value {
188195
ColumnarValue::Array(a) => {
189196
assert_eq!(a.len(), rows.len());
190-
191-
self.arrays.push(a);
192-
for (array_ix, row_ix) in rows
193-
.as_primitive::<UInt32Type>()
194-
.values()
195-
.iter()
196-
.enumerate()
197-
{
198-
self.indices[*row_ix as usize] = (array_index, array_ix);
197+
if rows.len() == self.partial_result_indices.len() {
198+
self.set_covering_result(ColumnarValue::Array(a));
199+
} else {
200+
self.add_partial_result(rows, a.to_data());
199201
}
200202
}
201203
ColumnarValue::Scalar(s) => {
202-
self.arrays.push(s.to_array()?);
203-
for row_ix in rows.as_primitive::<UInt32Type>().values().iter() {
204-
self.indices[*row_ix as usize] = (array_index, 0);
204+
if rows.len() == self.partial_result_indices.len() {
205+
self.set_covering_result(ColumnarValue::Scalar(s));
206+
} else {
207+
self.add_partial_result(
208+
rows,
209+
s.to_array_of_size(rows.len())?.to_data(),
210+
);
205211
}
206212
}
207213
}
208214
Ok(())
209215
}
210216

211-
fn finish(mut self) -> Result<ColumnarValue> {
212-
if self.arrays.len() == 1 {
213-
// The first array is always a single null value.
214-
if self.indices.len() == 1 {
215-
// If there's only a single row, reuse the array
216-
Ok(ColumnarValue::Array(self.arrays.remove(0)))
217-
} else {
218-
// Otherwise make a new null array with the correct type and length
219-
Ok(ColumnarValue::Array(new_null_array(
220-
self.arrays[0].data_type(),
221-
self.indices.len(),
222-
)))
217+
fn add_partial_result(&mut self, rows: &ArrayRef, data: ArrayData) {
218+
assert!(self.covering_result.is_none());
219+
220+
self.partial_results.push(data);
221+
let array_index = self.partial_results.len();
222+
223+
for row_ix in rows.as_primitive::<UInt32Type>().values().iter() {
224+
self.partial_result_indices[*row_ix as usize] = array_index;
225+
}
226+
}
227+
228+
fn set_covering_result(&mut self, value: ColumnarValue) {
229+
assert!(self.partial_results.is_empty());
230+
self.covering_result = Some(value);
231+
}
232+
233+
fn finish(self) -> Result<ColumnarValue> {
234+
match self.covering_result {
235+
Some(v) => {
236+
// If we have a covering result, we can just return it.
237+
Ok(v)
223238
}
224-
} else if self.arrays.len() == 2
225-
&& !self.indices.iter().any(|(array_ix, _)| *array_ix == 0)
226-
&& self.arrays[1].len() == self.indices.len()
227-
{
228-
// There's only a single non-null array and no references to the null array.
229-
// We can take a shortcut and return the non-null array directly.
230-
Ok(ColumnarValue::Array(self.arrays.remove(1)))
231-
} else {
232-
// Interleave arrays
233-
let array_refs = self.arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
234-
let interleaved_result = interleave(&array_refs, &self.indices)?;
235-
Ok(ColumnarValue::Array(interleaved_result))
239+
None => match self.partial_results.len() {
240+
0 => {
241+
// No covering result and no partial results.
242+
// This can happen for case expressions with no else branch where no rows
243+
// matched.
244+
Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
245+
&self.data_type,
246+
)?))
247+
}
248+
n => {
249+
// There are n partial results.
250+
// Merge into a single array.
251+
252+
let data_refs = self.partial_results.iter().collect();
253+
let mut mutable = MutableArrayData::new(
254+
data_refs,
255+
true,
256+
self.partial_result_indices.len(),
257+
);
258+
259+
// take_offsets keeps track of how many values have been taken from each array.
260+
let mut take_offsets = vec![0; n + 1];
261+
262+
let mut row_ix = 0;
263+
loop {
264+
let array_ix = self.partial_result_indices[row_ix];
265+
row_ix += 1;
266+
267+
// Determine the length of the slice to take.
268+
let start_offset = take_offsets[array_ix];
269+
let end_offset = start_offset + 1;
270+
while row_ix < self.partial_result_indices.len()
271+
&& self.partial_result_indices[row_ix] == array_ix
272+
{
273+
row_ix += 1;
274+
}
275+
276+
// Extend the result array either with nulls or with values from the array.
277+
if array_ix == 0 {
278+
mutable.extend_nulls(end_offset - start_offset);
279+
} else {
280+
mutable.extend(array_ix - 1, start_offset, end_offset);
281+
}
282+
283+
// Update the take_offsets array.
284+
take_offsets[array_ix] = end_offset;
285+
286+
if row_ix == self.partial_result_indices.len() {
287+
break;
288+
}
289+
}
290+
291+
Ok(ColumnarValue::Array(make_array(mutable.freeze())))
292+
}
293+
},
236294
}
237295
}
238296
}
@@ -313,12 +371,11 @@ impl CaseExpr {
313371
let optimize_filters = batch.num_columns() > 1;
314372

315373
let return_type = self.data_type(&batch.schema())?;
316-
let mut interleave_builder =
317-
InterleaveBuilder::new(&return_type, batch.num_rows());
374+
let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows());
318375

319376
// `remainder_rows` contains the indices of the rows that need to be evaluated
320377
let mut remainder_rows: ArrayRef =
321-
Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32));
378+
Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32));
322379
// `remainder_batch` contains the rows themselves that need to be evaluated
323380
let mut remainder_batch = Cow::Borrowed(batch);
324381

@@ -337,20 +394,20 @@ impl CaseExpr {
337394
let base_nulls = is_null(base_value.as_ref())?;
338395
if base_nulls.true_count() > 0 {
339396
// If there is an else expression, use that as the default value for the null rows
340-
// Otherwise the default `null` value from the eInterleaveBuilder will be used.
397+
// Otherwise the default `null` value from the result builder will be used.
341398
if let Some(e) = self.else_expr() {
342399
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
343400

344401
let nulls_filter = create_filter(&base_nulls, optimize_filters);
345402
let nulls_batch = filter_record_batch(&remainder_batch, &nulls_filter)?;
346403
let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?;
347404
let nulls_value = expr.evaluate(&nulls_batch)?;
348-
interleave_builder.add_result(&nulls_rows, nulls_value)?;
405+
result_builder.add_result(&nulls_rows, nulls_value)?;
349406
}
350407

351408
// All base values were null, so we can return early
352409
if base_nulls.true_count() == remainder_batch.num_rows() {
353-
return interleave_builder.finish();
410+
return result_builder.finish();
354411
}
355412

356413
// Remove the null rows from the remainder batch
@@ -397,14 +454,14 @@ impl CaseExpr {
397454

398455
let then_expression = &self.when_then_expr[i].1;
399456
let then_value = then_expression.evaluate(&then_batch)?;
400-
interleave_builder.add_result(&then_rows, then_value)?;
457+
result_builder.add_result(&then_rows, then_value)?;
401458

402459
// If the 'when' predicate matched all remaining row, there's nothing left to do so
403460
// we can return early
404461
if remainder_batch.num_rows() == when_match_count
405462
|| (self.else_expr.is_none() && i == self.when_then_expr.len() - 1)
406463
{
407-
return interleave_builder.finish();
464+
return result_builder.finish();
408465
}
409466

410467
// Prepare the next when branch (or the else branch)
@@ -422,10 +479,10 @@ impl CaseExpr {
422479
// keep `else_expr`'s data type and return type consistent
423480
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
424481
let else_value = expr.evaluate(&remainder_batch)?;
425-
interleave_builder.add_result(&remainder_rows, else_value)?;
482+
result_builder.add_result(&remainder_rows, else_value)?;
426483
}
427484

428-
interleave_builder.finish()
485+
result_builder.finish()
429486
}
430487

431488
/// This function evaluates the form of CASE where each WHEN expression is a boolean
@@ -439,8 +496,7 @@ impl CaseExpr {
439496
let optimize_filters = batch.num_columns() > 1;
440497

441498
let return_type = self.data_type(&batch.schema())?;
442-
let mut interleave_builder =
443-
InterleaveBuilder::new(&return_type, batch.num_rows());
499+
let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows());
444500

445501
// `remainder_rows` contains the indices of the rows that need to be evaluated
446502
let mut remainder_rows: ArrayRef =
@@ -480,14 +536,14 @@ impl CaseExpr {
480536

481537
let then_expression = &self.when_then_expr[i].1;
482538
let then_value = then_expression.evaluate(&then_batch)?;
483-
interleave_builder.add_result(&then_rows, then_value)?;
539+
result_builder.add_result(&then_rows, then_value)?;
484540

485541
// If the 'when' predicate matched all remaining row, there's nothing left to do so
486542
// we can return early
487543
if remainder_batch.num_rows() == when_match_count
488544
|| (self.else_expr.is_none() && i == self.when_then_expr.len() - 1)
489545
{
490-
return interleave_builder.finish();
546+
return result_builder.finish();
491547
}
492548

493549
// Prepare the next when branch (or the else branch)
@@ -504,10 +560,10 @@ impl CaseExpr {
504560
// keep `else_expr`'s data type and return type consistent
505561
let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?;
506562
let else_value = expr.evaluate(&remainder_batch)?;
507-
interleave_builder.add_result(&remainder_rows, else_value)?;
563+
result_builder.add_result(&remainder_rows, else_value)?;
508564
}
509565

510-
interleave_builder.finish()
566+
result_builder.finish()
511567
}
512568

513569
/// This function evaluates the specialized case of:

0 commit comments

Comments
 (0)