@@ -21,8 +21,7 @@ use crate::PhysicalExpr;
2121use arrow:: array:: * ;
2222use arrow:: compute:: kernels:: zip:: zip;
2323use arrow:: compute:: {
24- interleave, is_null, not, nullif, prep_null_mask_filter, FilterBuilder ,
25- FilterPredicate ,
24+ is_null, not, nullif, prep_null_mask_filter, FilterBuilder , FilterPredicate ,
2625} ;
2726use arrow:: datatypes:: { DataType , Schema , UInt32Type } ;
2827use arrow:: error:: ArrowError ;
@@ -40,7 +39,7 @@ use std::{any::Any, sync::Arc};
4039type WhenThen = ( Arc < dyn PhysicalExpr > , Arc < dyn PhysicalExpr > ) ;
4140
4241#[ derive( Debug , Hash , PartialEq , Eq ) ]
43- enum EvalMethod {
42+ pub enum EvalMethod {
4443 /// CASE WHEN condition THEN result
4544 /// [WHEN ...]
4645 /// [ELSE result]
@@ -96,7 +95,7 @@ pub struct CaseExpr {
9695 /// Optional "else" expression
9796 else_expr : Option < Arc < dyn PhysicalExpr > > ,
9897 /// Evaluation method to use
99- eval_method : EvalMethod ,
98+ pub eval_method : EvalMethod ,
10099}
101100
102101impl std:: fmt:: Display for CaseExpr {
@@ -159,18 +158,27 @@ fn filter_array(
159158 filter. filter ( array)
160159}
161160
162- struct InterleaveBuilder {
163- indices : Vec < ( usize , usize ) > ,
164- arrays : Vec < ArrayRef > ,
161+ struct ResultBuilder {
162+ data_type : DataType ,
163+ // A Vec of partial results that should be merged. `partial_result_indices` contains
164+ // indexes into this vec.
165+ partial_results : Vec < ArrayData > ,
166+ // Indicates per result row from which array in `partial_results` a value should be taken.
167+ // The indexes in this array are offset by +1. The special value 0 indicates null values.
168+ partial_result_indices : Vec < usize > ,
169+ // An optional result that is the covering result for all rows.
170+ // This is used as an optimisation to avoid the cost of merging when all rows
171+ // evaluate to the same case branch.
172+ covering_result : Option < ColumnarValue > ,
165173}
166174
167- impl InterleaveBuilder {
175+ impl ResultBuilder {
168176 fn new ( data_type : & DataType , capacity : usize ) -> Self {
169- // By settings indices to (0, 0) every entry points to the single
170- // null value in the first array.
171177 Self {
172- indices : vec ! [ ( 0 , 0 ) ; capacity] ,
173- arrays : vec ! [ new_null_array( data_type, 1 ) ] ,
178+ data_type : data_type. clone ( ) ,
179+ partial_result_indices : vec ! [ 0 ; capacity] ,
180+ partial_results : vec ! [ ] ,
181+ covering_result : None ,
174182 }
175183 }
176184
@@ -183,56 +191,106 @@ impl InterleaveBuilder {
183191 /// If `value` is an array, the values from the array and the indices from `rows` will be
184192 /// processed pairwise.
185193 fn add_result ( & mut self , rows : & ArrayRef , value : ColumnarValue ) -> Result < ( ) > {
186- let array_index = self . arrays . len ( ) ;
187194 match value {
188195 ColumnarValue :: Array ( a) => {
189196 assert_eq ! ( a. len( ) , rows. len( ) ) ;
190-
191- self . arrays . push ( a) ;
192- for ( array_ix, row_ix) in rows
193- . as_primitive :: < UInt32Type > ( )
194- . values ( )
195- . iter ( )
196- . enumerate ( )
197- {
198- self . indices [ * row_ix as usize ] = ( array_index, array_ix) ;
197+ if rows. len ( ) == self . partial_result_indices . len ( ) {
198+ self . set_covering_result ( ColumnarValue :: Array ( a) ) ;
199+ } else {
200+ self . add_partial_result ( rows, a. to_data ( ) ) ;
199201 }
200202 }
201203 ColumnarValue :: Scalar ( s) => {
202- self . arrays . push ( s. to_array ( ) ?) ;
203- for row_ix in rows. as_primitive :: < UInt32Type > ( ) . values ( ) . iter ( ) {
204- self . indices [ * row_ix as usize ] = ( array_index, 0 ) ;
204+ if rows. len ( ) == self . partial_result_indices . len ( ) {
205+ self . set_covering_result ( ColumnarValue :: Scalar ( s) ) ;
206+ } else {
207+ self . add_partial_result (
208+ rows,
209+ s. to_array_of_size ( rows. len ( ) ) ?. to_data ( ) ,
210+ ) ;
205211 }
206212 }
207213 }
208214 Ok ( ( ) )
209215 }
210216
211- fn finish ( mut self ) -> Result < ColumnarValue > {
212- if self . arrays . len ( ) == 1 {
213- // The first array is always a single null value.
214- if self . indices . len ( ) == 1 {
215- // If there's only a single row, reuse the array
216- Ok ( ColumnarValue :: Array ( self . arrays . remove ( 0 ) ) )
217- } else {
218- // Otherwise make a new null array with the correct type and length
219- Ok ( ColumnarValue :: Array ( new_null_array (
220- self . arrays [ 0 ] . data_type ( ) ,
221- self . indices . len ( ) ,
222- ) ) )
217+ fn add_partial_result ( & mut self , rows : & ArrayRef , data : ArrayData ) {
218+ assert ! ( self . covering_result. is_none( ) ) ;
219+
220+ self . partial_results . push ( data) ;
221+ let array_index = self . partial_results . len ( ) ;
222+
223+ for row_ix in rows. as_primitive :: < UInt32Type > ( ) . values ( ) . iter ( ) {
224+ self . partial_result_indices [ * row_ix as usize ] = array_index;
225+ }
226+ }
227+
228+ fn set_covering_result ( & mut self , value : ColumnarValue ) {
229+ assert ! ( self . partial_results. is_empty( ) ) ;
230+ self . covering_result = Some ( value) ;
231+ }
232+
233+ fn finish ( self ) -> Result < ColumnarValue > {
234+ match self . covering_result {
235+ Some ( v) => {
236+ // If we have a covering result, we can just return it.
237+ Ok ( v)
223238 }
224- } else if self . arrays . len ( ) == 2
225- && !self . indices . iter ( ) . any ( |( array_ix, _) | * array_ix == 0 )
226- && self . arrays [ 1 ] . len ( ) == self . indices . len ( )
227- {
228- // There's only a single non-null array and no references to the null array.
229- // We can take a shortcut and return the non-null array directly.
230- Ok ( ColumnarValue :: Array ( self . arrays . remove ( 1 ) ) )
231- } else {
232- // Interleave arrays
233- let array_refs = self . arrays . iter ( ) . map ( |a| a. as_ref ( ) ) . collect :: < Vec < _ > > ( ) ;
234- let interleaved_result = interleave ( & array_refs, & self . indices ) ?;
235- Ok ( ColumnarValue :: Array ( interleaved_result) )
239+ None => match self . partial_results . len ( ) {
240+ 0 => {
241+ // No covering result and no partial results.
242+ // This can happen for case expressions with no else branch where no rows
243+ // matched.
244+ Ok ( ColumnarValue :: Scalar ( ScalarValue :: try_new_null (
245+ & self . data_type ,
246+ ) ?) )
247+ }
248+ n => {
249+ // There are n partial results.
250+ // Merge into a single array.
251+
252+ let data_refs = self . partial_results . iter ( ) . collect ( ) ;
253+ let mut mutable = MutableArrayData :: new (
254+ data_refs,
255+ true ,
256+ self . partial_result_indices . len ( ) ,
257+ ) ;
258+
259+ // take_offsets keeps track of how many values have been taken from each array.
260+ let mut take_offsets = vec ! [ 0 ; n + 1 ] ;
261+
262+ let mut row_ix = 0 ;
263+ loop {
264+ let array_ix = self . partial_result_indices [ row_ix] ;
265+ row_ix += 1 ;
266+
267+ // Determine the length of the slice to take.
268+ let start_offset = take_offsets[ array_ix] ;
269+ let end_offset = start_offset + 1 ;
270+ while row_ix < self . partial_result_indices . len ( )
271+ && self . partial_result_indices [ row_ix] == array_ix
272+ {
273+ row_ix += 1 ;
274+ }
275+
276+ // Extend the result array either with nulls or with values from the array.
277+ if array_ix == 0 {
278+ mutable. extend_nulls ( end_offset - start_offset) ;
279+ } else {
280+ mutable. extend ( array_ix - 1 , start_offset, end_offset) ;
281+ }
282+
283+ // Update the take_offsets array.
284+ take_offsets[ array_ix] = end_offset;
285+
286+ if row_ix == self . partial_result_indices . len ( ) {
287+ break ;
288+ }
289+ }
290+
291+ Ok ( ColumnarValue :: Array ( make_array ( mutable. freeze ( ) ) ) )
292+ }
293+ } ,
236294 }
237295 }
238296}
@@ -313,12 +371,11 @@ impl CaseExpr {
313371 let optimize_filters = batch. num_columns ( ) > 1 ;
314372
315373 let return_type = self . data_type ( & batch. schema ( ) ) ?;
316- let mut interleave_builder =
317- InterleaveBuilder :: new ( & return_type, batch. num_rows ( ) ) ;
374+ let mut result_builder = ResultBuilder :: new ( & return_type, batch. num_rows ( ) ) ;
318375
319376 // `remainder_rows` contains the indices of the rows that need to be evaluated
320377 let mut remainder_rows: ArrayRef =
321- Arc :: new ( UInt32Array :: from_iter ( 0 ..batch. num_rows ( ) as u32 ) ) ;
378+ Arc :: new ( UInt32Array :: from_iter_values ( 0 ..batch. num_rows ( ) as u32 ) ) ;
322379 // `remainder_batch` contains the rows themselves that need to be evaluated
323380 let mut remainder_batch = Cow :: Borrowed ( batch) ;
324381
@@ -337,20 +394,20 @@ impl CaseExpr {
337394 let base_nulls = is_null ( base_value. as_ref ( ) ) ?;
338395 if base_nulls. true_count ( ) > 0 {
339396 // If there is an else expression, use that as the default value for the null rows
340- // Otherwise the default `null` value from the eInterleaveBuilder will be used.
397+ // Otherwise the default `null` value from the result builder will be used.
341398 if let Some ( e) = self . else_expr ( ) {
342399 let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) ) ?;
343400
344401 let nulls_filter = create_filter ( & base_nulls, optimize_filters) ;
345402 let nulls_batch = filter_record_batch ( & remainder_batch, & nulls_filter) ?;
346403 let nulls_rows = filter_array ( & remainder_rows, & nulls_filter) ?;
347404 let nulls_value = expr. evaluate ( & nulls_batch) ?;
348- interleave_builder . add_result ( & nulls_rows, nulls_value) ?;
405+ result_builder . add_result ( & nulls_rows, nulls_value) ?;
349406 }
350407
351408 // All base values were null, so we can return early
352409 if base_nulls. true_count ( ) == remainder_batch. num_rows ( ) {
353- return interleave_builder . finish ( ) ;
410+ return result_builder . finish ( ) ;
354411 }
355412
356413 // Remove the null rows from the remainder batch
@@ -397,14 +454,14 @@ impl CaseExpr {
397454
398455 let then_expression = & self . when_then_expr [ i] . 1 ;
399456 let then_value = then_expression. evaluate ( & then_batch) ?;
400- interleave_builder . add_result ( & then_rows, then_value) ?;
457+ result_builder . add_result ( & then_rows, then_value) ?;
401458
402459 // If the 'when' predicate matched all remaining row, there's nothing left to do so
403460 // we can return early
404461 if remainder_batch. num_rows ( ) == when_match_count
405462 || ( self . else_expr . is_none ( ) && i == self . when_then_expr . len ( ) - 1 )
406463 {
407- return interleave_builder . finish ( ) ;
464+ return result_builder . finish ( ) ;
408465 }
409466
410467 // Prepare the next when branch (or the else branch)
@@ -422,10 +479,10 @@ impl CaseExpr {
422479 // keep `else_expr`'s data type and return type consistent
423480 let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) ) ?;
424481 let else_value = expr. evaluate ( & remainder_batch) ?;
425- interleave_builder . add_result ( & remainder_rows, else_value) ?;
482+ result_builder . add_result ( & remainder_rows, else_value) ?;
426483 }
427484
428- interleave_builder . finish ( )
485+ result_builder . finish ( )
429486 }
430487
431488 /// This function evaluates the form of CASE where each WHEN expression is a boolean
@@ -439,8 +496,7 @@ impl CaseExpr {
439496 let optimize_filters = batch. num_columns ( ) > 1 ;
440497
441498 let return_type = self . data_type ( & batch. schema ( ) ) ?;
442- let mut interleave_builder =
443- InterleaveBuilder :: new ( & return_type, batch. num_rows ( ) ) ;
499+ let mut result_builder = ResultBuilder :: new ( & return_type, batch. num_rows ( ) ) ;
444500
445501 // `remainder_rows` contains the indices of the rows that need to be evaluated
446502 let mut remainder_rows: ArrayRef =
@@ -480,14 +536,14 @@ impl CaseExpr {
480536
481537 let then_expression = & self . when_then_expr [ i] . 1 ;
482538 let then_value = then_expression. evaluate ( & then_batch) ?;
483- interleave_builder . add_result ( & then_rows, then_value) ?;
539+ result_builder . add_result ( & then_rows, then_value) ?;
484540
485541 // If the 'when' predicate matched all remaining row, there's nothing left to do so
486542 // we can return early
487543 if remainder_batch. num_rows ( ) == when_match_count
488544 || ( self . else_expr . is_none ( ) && i == self . when_then_expr . len ( ) - 1 )
489545 {
490- return interleave_builder . finish ( ) ;
546+ return result_builder . finish ( ) ;
491547 }
492548
493549 // Prepare the next when branch (or the else branch)
@@ -504,10 +560,10 @@ impl CaseExpr {
504560 // keep `else_expr`'s data type and return type consistent
505561 let expr = try_cast ( Arc :: clone ( e) , & batch. schema ( ) , return_type. clone ( ) ) ?;
506562 let else_value = expr. evaluate ( & remainder_batch) ?;
507- interleave_builder . add_result ( & remainder_rows, else_value) ?;
563+ result_builder . add_result ( & remainder_rows, else_value) ?;
508564 }
509565
510- interleave_builder . finish ( )
566+ result_builder . finish ( )
511567 }
512568
513569 /// This function evaluates the specialized case of:
0 commit comments