|
17 | 17 |
|
18 | 18 | use crate::expressions::try_cast; |
19 | 19 | use crate::PhysicalExpr; |
20 | | -use std::borrow::Cow; |
21 | | -use std::hash::Hash; |
22 | | -use std::{any::Any, sync::Arc}; |
23 | | - |
24 | 20 | use arrow::array::*; |
25 | 21 | use arrow::compute::kernels::zip::zip; |
26 | | -use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; |
| 22 | +use arrow::compute::{ |
| 23 | + and, and_not, filter_record_batch, is_null, not, nullif, or, prep_null_mask_filter, |
| 24 | +}; |
27 | 25 | use arrow::datatypes::{DataType, Schema}; |
| 26 | +use arrow::error::ArrowError; |
28 | 27 | use datafusion_common::cast::as_boolean_array; |
29 | 28 | use datafusion_common::{ |
30 | 29 | exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, |
31 | 30 | }; |
32 | 31 | use datafusion_expr::ColumnarValue; |
| 32 | +use std::borrow::Cow; |
| 33 | +use std::hash::Hash; |
| 34 | +use std::{any::Any, sync::Arc}; |
33 | 35 |
|
34 | 36 | use super::{Column, Literal}; |
35 | 37 | use datafusion_physical_expr_common::datum::compare_with_eq; |
| 38 | +use datafusion_physical_expr_common::utils::scatter; |
36 | 39 | use itertools::Itertools; |
37 | 40 |
|
38 | 41 | type WhenThen = (Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>); |
@@ -122,6 +125,24 @@ fn is_cheap_and_infallible(expr: &Arc<dyn PhysicalExpr>) -> bool { |
122 | 125 | expr.as_any().is::<Column>() |
123 | 126 | } |
124 | 127 |
|
| 128 | +fn merge_result( |
| 129 | + current_value: &dyn Array, |
| 130 | + then_value: ColumnarValue, |
| 131 | + then_scatter: &BooleanArray, |
| 132 | +) -> std::result::Result<ArrayRef, ArrowError> { |
| 133 | + match then_value { |
| 134 | + ColumnarValue::Scalar(ScalarValue::Null) => nullif(current_value, then_scatter), |
| 135 | + ColumnarValue::Scalar(then_value) => { |
| 136 | + zip(then_scatter, &then_value.to_scalar()?, &current_value) |
| 137 | + } |
| 138 | + ColumnarValue::Array(then_value) => { |
| 139 | + // TODO this operation should probably be feasible in one pass |
| 140 | + let scattered_then = scatter(then_scatter, then_value.as_ref())?; |
| 141 | + zip(then_scatter, &scattered_then, &current_value) |
| 142 | + } |
| 143 | + } |
| 144 | +} |
| 145 | + |
125 | 146 | impl CaseExpr { |
126 | 147 | /// Create a new CASE WHEN expression |
127 | 148 | pub fn try_new( |
@@ -286,64 +307,72 @@ impl CaseExpr { |
286 | 307 |
|
287 | 308 | // start with nulls as default output |
288 | 309 | let mut current_value = new_null_array(&return_type, batch.num_rows()); |
289 | | - let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); |
290 | | - let mut remainder_count = batch.num_rows(); |
291 | | - for i in 0..self.when_then_expr.len() { |
292 | | - // If there are no rows left to process, break out of the loop early |
293 | | - if remainder_count == 0 { |
294 | | - break; |
295 | | - } |
296 | 310 |
|
| 311 | + let mut remainder_scatter = BooleanArray::from(vec![true; batch.num_rows()]); |
| 312 | + let mut remainder_batch = Cow::Borrowed(batch); |
| 313 | + |
| 314 | + for i in 0..self.when_then_expr.len() { |
| 315 | + // Evaluate the 'when' predicate for the remainder batch |
| 316 | + // This results in a boolean array with the same length as the remaining number of rows |
297 | 317 | let when_predicate = &self.when_then_expr[i].0; |
298 | | - let when_value = when_predicate.evaluate_selection(batch, &remainder)?; |
299 | | - let when_value = when_value.into_array(batch.num_rows())?; |
| 318 | + let when_value = when_predicate.evaluate(&remainder_batch)?; |
| 319 | + let when_value = when_value.to_array(remainder_batch.num_rows())?; |
300 | 320 | let when_value = as_boolean_array(&when_value).map_err(|_| { |
301 | 321 | internal_datafusion_err!("WHEN expression did not return a BooleanArray") |
302 | 322 | })?; |
303 | | - // Treat 'NULL' as false value |
304 | | - let when_value = match when_value.null_count() { |
305 | | - 0 => Cow::Borrowed(when_value), |
306 | | - _ => Cow::Owned(prep_null_mask_filter(when_value)), |
307 | | - }; |
308 | | - // Make sure we only consider rows that have not been matched yet |
309 | | - let when_value = and(&when_value, &remainder)?; |
310 | 323 |
|
311 | | - // If the predicate did not match any rows, continue to the next branch immediately |
| 324 | + // If the 'when' predicate did not match any rows, continue to the next branch immediately |
312 | 325 | let when_match_count = when_value.true_count(); |
313 | 326 | if when_match_count == 0 { |
314 | 327 | continue; |
315 | 328 | } |
316 | 329 |
|
| 330 | + // Make sure 'NULL' is treated as false |
| 331 | + let when_value = match when_value.null_count() { |
| 332 | + 0 => Cow::Borrowed(when_value), |
| 333 | + _ => Cow::Owned(prep_null_mask_filter(when_value)), |
| 334 | + }; |
| 335 | + |
| 336 | + // Filter the remainder batch based on the 'when' value |
| 337 | + // This results in a batch containing only the rows that need to be evaluated |
| 338 | + // for the current branch |
| 339 | + let then_batch = filter_record_batch(&remainder_batch, &when_value)?; |
| 340 | + |
| 341 | + // Evaluate the then expression for the matching rows |
317 | 342 | let then_expression = &self.when_then_expr[i].1; |
318 | | - let then_value = then_expression.evaluate_selection(batch, &when_value)?; |
| 343 | + let then_value = then_expression.evaluate(&then_batch)?; |
319 | 344 |
|
320 | | - current_value = match then_value { |
321 | | - ColumnarValue::Scalar(ScalarValue::Null) => { |
322 | | - nullif(current_value.as_ref(), &when_value)? |
323 | | - } |
324 | | - ColumnarValue::Scalar(then_value) => { |
325 | | - zip(&when_value, &then_value.to_scalar()?, &current_value)? |
326 | | - } |
327 | | - ColumnarValue::Array(then_value) => { |
328 | | - zip(&when_value, &then_value, &current_value)? |
329 | | - } |
330 | | - }; |
| 345 | + // Expand the 'when' match array using the 'remainder scatter' array |
| 346 | + // This results in a truthy boolean array that we can use to merge the |
| 347 | + // 'then' values with the `current_value` array. |
| 348 | + let then_merge = scatter(&remainder_scatter, when_value.as_ref())?; |
| 349 | + let then_merge = then_merge.as_boolean(); |
331 | 350 |
|
332 | | - // Succeed tuples should be filtered out for short-circuit evaluation, |
333 | | - // null values for the current when expr should be kept |
334 | | - remainder = and_not(&remainder, &when_value)?; |
335 | | - remainder_count -= when_match_count; |
| 351 | + // Merge the 'then' values with the `current_value` array |
| 352 | + current_value = merge_result(&current_value, then_value, then_merge)?; |
| 353 | + |
| 354 | + // If the 'when' predicate matched all remaining rows, there's nothing left to do so |
| 355 | + // we can return early |
| 356 | + if remainder_batch.num_rows() == when_match_count { |
| 357 | + return Ok(ColumnarValue::Array(current_value)); |
| 358 | + } |
| 359 | + |
| 360 | + // Clear the positions in 'remainder scatter' for which we just evaluated a value |
| 361 | + remainder_scatter = and_not(&remainder_scatter, then_merge)?; |
| 362 | + |
| 363 | + // Finally, prepare the remainder batch for the next branch |
| 364 | + let next_selection = not(&when_value)?; |
| 365 | + remainder_batch = |
| 366 | + Cow::Owned(filter_record_batch(&remainder_batch, &next_selection)?); |
336 | 367 | } |
337 | 368 |
|
| 369 | + // If we reached this point, some rows were left unmatched. |
| 370 | + // Check if those need to be evaluated using the 'else' expression. |
338 | 371 | if let Some(e) = self.else_expr() { |
339 | | - if remainder_count > 0 { |
340 | | - // keep `else_expr`'s data type and return type consistent |
341 | | - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; |
342 | | - let else_ = expr |
343 | | - .evaluate_selection(batch, &remainder)? |
344 | | - .into_array(batch.num_rows())?; |
345 | | - current_value = zip(&remainder, &else_, &current_value)?; |
346 | | - } |
| 372 | + // keep `else_expr`'s data type and return type consistent |
| 373 | + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; |
| 374 | + let else_value = expr.evaluate(&remainder_batch)?; |
| 375 | + current_value = merge_result(&current_value, else_value, &remainder_scatter)?; |
347 | 376 | } |
348 | 377 |
|
349 | 378 | Ok(ColumnarValue::Array(current_value)) |
|
0 commit comments