diff --git a/arrow-select/src/coalesce.rs b/arrow-select/src/coalesce.rs
index 891d62fc3aa6..be2bbcafb674 100644
--- a/arrow-select/src/coalesce.rs
+++ b/arrow-select/src/coalesce.rs
@@ -93,6 +93,25 @@ use primitive::InProgressPrimitiveArray;
 /// assert!(coalescer.next_completed_batch().is_none());
 /// ```
 ///
+/// Example of non-strict batch production (`with_exact_size(false)`):
+/// ```
+/// # use arrow_array::record_batch;
+/// # use arrow_select::coalesce::BatchCoalescer;
+/// let batch1 = record_batch!(("a", Int32, [1, 2, 3])).unwrap();
+/// let batch2 = record_batch!(("a", Int32, [4, 5])).unwrap();
+///
+/// // Non-strict: a batch is produced once buffered rows >= target, and may be larger than target
+/// let mut coalescer = BatchCoalescer::new(batch1.schema(), 4).with_exact_size(false);
+/// coalescer.push_batch(batch1).unwrap();
+/// // still < 4 rows buffered
+/// assert!(coalescer.next_completed_batch().is_none());
+/// coalescer.push_batch(batch2).unwrap();
+/// // now buffered >= 4: non-strict mode emits all buffered rows (5)
+/// let finished = coalescer.next_completed_batch().unwrap();
+/// let expected = record_batch!(("a", Int32, [1, 2, 3, 4, 5])).unwrap();
+/// assert_eq!(finished, expected);
+/// ```
+///
 /// # Background
 ///
 /// Generally speaking, larger [`RecordBatch`]es are more efficient to process
@@ -128,6 +147,14 @@ use primitive::InProgressPrimitiveArray;
 ///
 /// 2. The output is a sequence of batches, with all but the last being at exactly
 /// `target_batch_size` rows.
+///
+/// Notes on `exact_size`:
+///
+/// - `exact_size == true` (strict, the default): all output batches except the
+///   final one have exactly `target_batch_size` rows.
+/// - `exact_size == false` (non-strict): an output batch is produced once the
+///   buffered rows reach or exceed `target_batch_size`, so a produced batch may
+///   be larger than `target_batch_size`.
 #[derive(Debug)]
 pub struct BatchCoalescer {
     /// The input schema
@@ -142,6 +169,8 @@ pub struct BatchCoalescer {
     buffered_rows: usize,
     /// Completed batches
     completed: VecDeque<RecordBatch>,
+    /// Whether the output batches must have exactly `target_batch_size` rows
+    exact_size: bool,
 }
 
 impl BatchCoalescer {
@@ -166,9 +195,46 @@ impl BatchCoalescer {
             // We will for sure store at least one completed batch
             completed: VecDeque::with_capacity(1),
             buffered_rows: 0,
+            exact_size: true,
         }
     }
 
+    /// Controls whether output batches are produced with exactly `target_batch_size` rows.
+    ///
+    /// When `exact_size == true` (the default) the coalescer produces batches of
+    /// exactly `target_batch_size` rows, except possibly the final batch. This is
+    /// the historical behavior.
+    ///
+    /// When `exact_size == false` the coalescer produces a batch once the
+    /// buffered rows reach or exceed `target_batch_size`, so the produced batch
+    /// may be larger than `target_batch_size`.
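+    ///
+    /// # Example
+    ///
+    /// A short sketch contrasting the two modes on the same input (this mirrors
+    /// the module level example above):
+    ///
+    /// ```
+    /// # use arrow_array::record_batch;
+    /// # use arrow_select::coalesce::BatchCoalescer;
+    /// // 5 input rows with a target of 4 rows per batch
+    /// let batch = record_batch!(("a", Int32, [1, 2, 3, 4, 5])).unwrap();
+    ///
+    /// // strict (the default): exactly 4 rows are emitted, 1 row stays buffered
+    /// let mut strict = BatchCoalescer::new(batch.schema(), 4);
+    /// strict.push_batch(batch.clone()).unwrap();
+    /// assert_eq!(strict.next_completed_batch().unwrap().num_rows(), 4);
+    ///
+    /// // non-strict: all 5 buffered rows are emitted as a single batch
+    /// let mut relaxed = BatchCoalescer::new(batch.schema(), 4).with_exact_size(false);
+    /// relaxed.push_batch(batch).unwrap();
+    /// assert_eq!(relaxed.next_completed_batch().unwrap().num_rows(), 5);
+    /// ```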
+    pub fn with_exact_size(mut self, exact: bool) -> Self {
+        self.exact_size = exact;
+        self
+    }
+
     /// Return the schema of the output batches
     pub fn schema(&self) -> SchemaRef {
         Arc::clone(&self.schema)
     }
@@ -241,48 +307,56 @@ impl BatchCoalescer {
             return Ok(());
         }
 
-        // setup input rows
+        // set sources
         assert_eq!(arrays.len(), self.in_progress_arrays.len());
         self.in_progress_arrays
             .iter_mut()
             .zip(arrays)
-            .for_each(|(in_progress, array)| {
-                in_progress.set_source(Some(array));
-            });
+            .for_each(|(in_progress, array)| in_progress.set_source(Some(array)));
 
-        // If pushing this batch would exceed the target batch size,
-        // finish the current batch and start a new one
         let mut offset = 0;
-        while num_rows > (self.target_batch_size - self.buffered_rows) {
-            let remaining_rows = self.target_batch_size - self.buffered_rows;
-            debug_assert!(remaining_rows > 0);
-            // Copy remaining_rows from each array
-            for in_progress in self.in_progress_arrays.iter_mut() {
-                in_progress.copy_rows(offset, remaining_rows)?;
+        if self.exact_size {
+            // Strict: produce exactly target-sized batches (except the last)
+            while num_rows > (self.target_batch_size - self.buffered_rows) {
+                let remaining_rows = self.target_batch_size - self.buffered_rows;
+                debug_assert!(remaining_rows > 0);
+                for in_progress in self.in_progress_arrays.iter_mut() {
+                    in_progress.copy_rows(offset, remaining_rows)?;
+                }
+                self.buffered_rows += remaining_rows;
+                offset += remaining_rows;
+                num_rows -= remaining_rows;
+                self.finish_buffered_batch()?;
            }
-            self.buffered_rows += remaining_rows;
-            offset += remaining_rows;
-            num_rows -= remaining_rows;
-
-            self.finish_buffered_batch()?;
-        }
+            if num_rows > 0 {
+                for in_progress in self.in_progress_arrays.iter_mut() {
+                    in_progress.copy_rows(offset, num_rows)?;
+                }
+                self.buffered_rows += num_rows;
+            }
 
-        // Add any the remaining rows to the buffer
-        self.buffered_rows += num_rows;
-        if num_rows > 0 {
-            for in_progress in self.in_progress_arrays.iter_mut() {
-                in_progress.copy_rows(offset, num_rows)?;
+            // ensure the strict invariant: this only finishes an exactly full batch
+            if self.buffered_rows >= self.target_batch_size {
+                self.finish_buffered_batch()?;
+            }
+        } else {
+            // Non-strict: append all remaining rows, then emit if buffered >= target
+            if num_rows > 0 {
+                for in_progress in self.in_progress_arrays.iter_mut() {
+                    in_progress.copy_rows(offset, num_rows)?;
+                }
+                self.buffered_rows += num_rows;
             }
-        }
 
-        // If we have reached the target batch size, finalize the buffered batch
-        if self.buffered_rows >= self.target_batch_size {
-            self.finish_buffered_batch()?;
+            // If we have reached or exceeded the target, emit all buffered rows
+            if self.buffered_rows >= self.target_batch_size {
+                self.finish_buffered_batch()?;
+            }
         }
 
-        // clear in progress sources (to allow the memory to be freed)
+        // clear sources (to allow the memory to be freed)
         for in_progress in self.in_progress_arrays.iter_mut() {
             in_progress.set_source(None);
         }
@@ -1314,4 +1388,144 @@ mod tests {
         let options = RecordBatchOptions::new().with_row_count(Some(row_count));
         RecordBatch::try_new_with_options(schema, columns, &options).unwrap()
     }
+
+    // Tests for the non-strict mode (`with_exact_size(false)`)
+    #[test]
+    fn test_non_strict_small_batches() {
+        // two small batches -> combined once buffered >= target
+        let batch1 = uint32_batch(0..3); // 3 rows
+        let batch2 = uint32_batch(3..5); // 2 rows
+
+        let schema = Arc::clone(&batch1.schema());
+        let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 4).with_exact_size(false);
+
+        // push first batch (3 rows) -> not enough yet
+        coalescer.push_batch(batch1).unwrap();
+        assert!(coalescer.next_completed_batch().is_none());
+
+        // push second batch (2 rows) -> buffered becomes 5 >= 4, so non-strict emits all 5 rows
+        coalescer.push_batch(batch2).unwrap();
+        let out = coalescer
+            .next_completed_batch()
+            .expect("expected a completed batch");
+        assert_eq!(out.num_rows(), 5);
+
+        // check that the contents equal the concatenation 0..5
+        let expected = uint32_batch(0..5);
+        let actual = normalize_batch(out);
+        let expected = normalize_batch(expected);
+        assert_eq!(expected, actual);
+    }
+
+    #[test]
+    fn test_non_strict_single_large_batch() {
+        // one large batch > target: in non-strict mode the whole batch is emitted at once
+        let batch = uint32_batch(0..4096);
+        let schema = Arc::clone(&batch.schema());
+        let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 1000).with_exact_size(false);
+        coalescer.push_batch(batch).unwrap();
+        let out = coalescer
+            .next_completed_batch()
+            .expect("expected a completed batch");
+        assert_eq!(out.num_rows(), 4096);
+
+        // compare to expected contents
+        let expected = uint32_batch(0..4096);
+        let actual = normalize_batch(out);
+        let expected = normalize_batch(expected);
+        assert_eq!(expected, actual);
+    }
+
+    #[test]
+    fn test_strict_single_large_batch_multiple_outputs() {
+        // single large batch -> split into multiple exactly target-sized batches
+        let batch = uint32_batch(0..5000);
+        let schema = Arc::clone(&batch.schema());
+        let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 1000).with_exact_size(true);
+
+        coalescer.push_batch(batch).unwrap();
+
+        // should emit 5 batches of 1000 rows each
+        let mut outputs = vec![];
+        while let Some(b) = coalescer.next_completed_batch() {
+            outputs.push(b);
+        }
+        assert_eq!(outputs.len(), 5);
+        for (i, out) in outputs.into_iter().enumerate() {
+            assert_eq!(out.num_rows(), 1000);
+            let expected = uint32_batch((i * 1000) as u32..((i + 1) * 1000) as u32);
+            assert_eq!(normalize_batch(out), normalize_batch(expected));
+        }
+    }
+
+    #[test]
+    fn test_non_strict_multiple_emits_over_time() {
+        // multiple pushes, each pair eventually pushing buffered >= target and emitting
+        let b1 = uint32_batch(0..3); // 3 rows
+        let b2 = uint32_batch(3..5); // 2 rows -> 3 + 2 = 5, first emit
+        let b3 = uint32_batch(5..8); // 3 rows
+        let b4 = uint32_batch(8..10); // 2 rows -> 3 + 2 = 5, second emit
+
+        let schema = Arc::clone(&b1.schema());
+        let mut coalescer = BatchCoalescer::new(Arc::clone(&schema), 4).with_exact_size(false);
+
+        coalescer.push_batch(b1).unwrap();
+        assert!(coalescer.next_completed_batch().is_none());
+
+        coalescer.push_batch(b2).unwrap();
+        let out1 = coalescer
+            .next_completed_batch()
+            .expect("expected first batch");
+        assert_eq!(out1.num_rows(), 5);
+        assert_eq!(normalize_batch(out1), normalize_batch(uint32_batch(0..5)));
+
+        coalescer.push_batch(b3).unwrap();
+        assert!(coalescer.next_completed_batch().is_none());
+
+        coalescer.push_batch(b4).unwrap();
+        let out2 = coalescer
+            .next_completed_batch()
+            .expect("expected second batch");
+        assert_eq!(out2.num_rows(), 5);
+        assert_eq!(normalize_batch(out2), normalize_batch(uint32_batch(5..10)));
+    }
+
+    #[test]
+    fn test_non_strict_large_then_more_outputs() {
+        // first push a large batch (produces one big output), then push more small ones to produce another
+        let big = uint32_batch(0..5000);
+        let small1 = uint32_batch(5000..5002); // 2 rows
+        let small2 = uint32_batch(5002..5005); // 3 rows -> 2 + 3 = 5 >= 4, emit
+
+        let schema = Arc::clone(&big.schema());
+        // Use a small target (4) so that small1 + small2 will trigger an emit
+        let mut coalescer =
+            BatchCoalescer::new(Arc::clone(&schema), 4).with_exact_size(false);
+
+        // push big: non-strict mode emits the whole big batch (5000 rows)
+        coalescer.push_batch(big).unwrap();
+        let out_big = coalescer
+            .next_completed_batch()
+            .expect("expected big batch");
+        assert_eq!(out_big.num_rows(), 5000);
+        assert_eq!(
+            normalize_batch(out_big),
+            normalize_batch(uint32_batch(0..5000))
+        );
+
+        // push small1 (2 rows) -> not enough yet
+        coalescer.push_batch(small1).unwrap();
+        assert!(coalescer.next_completed_batch().is_none());
+
+        // push small2 (3 rows) -> buffered = 2 + 3 = 5 >= 4, non-strict emits all 5 rows
+        coalescer.push_batch(small2).unwrap();
+        let out_small = coalescer
+            .next_completed_batch()
+            .expect("expected small batch");
+        assert_eq!(out_small.num_rows(), 5);
+        assert_eq!(
+            normalize_batch(out_small),
+            normalize_batch(uint32_batch(5000..5005))
+        );
+    }
 }
diff --git a/arrow/benches/coalesce_kernels.rs b/arrow/benches/coalesce_kernels.rs
index 941882c70e8d..0d255243b077 100644
--- a/arrow/benches/coalesce_kernels.rs
+++ b/arrow/benches/coalesce_kernels.rs
@@ -232,7 +232,7 @@ fn filter_streams(
 ) {
     let schema = data_stream.schema();
     let batch_size = data_stream.batch_size();
-    let mut coalescer = BatchCoalescer::new(Arc::clone(schema), batch_size);
+    let mut coalescer = BatchCoalescer::new(Arc::clone(schema), batch_size).with_exact_size(false);
 
     while num_output_batches > 0 {
         let filter = filter_stream.next_filter();