From 45ca98549f0788acad4711eabf0d187a58cfd9ca Mon Sep 17 00:00:00 2001 From: RinChanNOW Date: Mon, 18 Dec 2023 10:12:18 +0800 Subject: [PATCH] refactor: abstract a common merger to do merge sort. (#14020) * refactor: abstract a common merger to do merge sort. * Replace the merger in `TransformSortSpill` with `HeapMerger`. * Replace the merger in `TransformSortMerge` with `HeapMerger`. * Ensure merger has at least two input streams. * Refactor `SortedStream` trait, output block along with order column. * Add `limit` to `HeapMerger`. * Refactor `HeapMerger`. * Replace the merger in `TransformMultiSortMerge` with `HeapMerger`. * Fix and add assertions. * Improve. * Recover `TransformMultiSortMerge`. * Refactor codes. --- Cargo.lock | 2 + src/query/expression/src/block.rs | 7 + src/query/pipeline/transforms/Cargo.toml | 4 + src/query/pipeline/transforms/src/lib.rs | 1 + .../src/processors/transforms/sort/cursor.rs | 7 + .../src/processors/transforms/sort/merger.rs | 321 ++++++++++++++++++ .../src/processors/transforms/sort/mod.rs | 2 + .../processors/transforms/sort/rows/common.rs | 6 +- .../processors/transforms/sort/rows/mod.rs | 16 +- .../processors/transforms/sort/rows/simple.rs | 11 +- .../src/processors/transforms/sort/utils.rs | 38 +++ .../transforms/transform_multi_sort_merge.rs | 12 +- .../transforms/transform_sort_merge.rs | 154 ++++----- .../transforms/transform_sort_merge_base.rs | 57 ++-- .../pipeline/transforms/tests/it/main.rs | 15 + .../pipeline/transforms/tests/it/merger.rs | 262 ++++++++++++++ .../src/pipelines/builders/builder_sort.rs | 55 +-- .../transforms/transform_sort_spill.rs | 174 ++-------- .../executor/physical_plans/physical_sort.rs | 5 +- 19 files changed, 825 insertions(+), 324 deletions(-) create mode 100644 src/query/pipeline/transforms/src/processors/transforms/sort/merger.rs create mode 100644 src/query/pipeline/transforms/tests/it/main.rs create mode 100644 src/query/pipeline/transforms/tests/it/merger.rs diff --git a/Cargo.lock b/Cargo.lock index bc40520848dc..e5ac5940a8cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3233,8 +3233,10 @@ dependencies = [ "databend-common-expression", "databend-common-pipeline-core", "databend-common-profile", + "itertools 0.10.5", "jsonb 0.3.0 (git+https://github.com/datafuselabs/jsonb?rev=582c139)", "match-template", + "rand 0.8.5", "serde", "typetag", ] diff --git a/src/query/expression/src/block.rs b/src/query/expression/src/block.rs index 1e46884df4bd..029239fc5c70 100644 --- a/src/query/expression/src/block.rs +++ b/src/query/expression/src/block.rs @@ -553,6 +553,13 @@ impl DataBlock { } DataBlock::new_with_meta(columns, self.num_rows, self.meta) } + + #[inline] + pub fn get_last_column(&self) -> &Column { + debug_assert!(!self.columns.is_empty()); + debug_assert!(self.columns.last().unwrap().value.as_column().is_some()); + self.columns.last().unwrap().value.as_column().unwrap() + } } impl TryFrom for ArrowChunk { diff --git a/src/query/pipeline/transforms/Cargo.toml b/src/query/pipeline/transforms/Cargo.toml index 0630c15e4654..3fff363caa10 100644 --- a/src/query/pipeline/transforms/Cargo.toml +++ b/src/query/pipeline/transforms/Cargo.toml @@ -22,5 +22,9 @@ match-template = { workspace = true } serde = { workspace = true } typetag = { workspace = true } +[dev-dependencies] +itertools = { workspace = true } +rand = { workspace = true } + [package.metadata.cargo-machete] ignored = ["match-template"] diff --git a/src/query/pipeline/transforms/src/lib.rs b/src/query/pipeline/transforms/src/lib.rs index 
addeeb815e23..339e62d63f84 100644 --- a/src/query/pipeline/transforms/src/lib.rs +++ b/src/query/pipeline/transforms/src/lib.rs @@ -15,5 +15,6 @@ #![feature(core_intrinsics)] #![feature(int_roundings)] #![feature(binary_heap_as_slice)] +#![feature(let_chains)] pub mod processors; diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/cursor.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/cursor.rs index 85804e80363a..4c6c422d77ea 100644 --- a/src/query/pipeline/transforms/src/processors/transforms/sort/cursor.rs +++ b/src/query/pipeline/transforms/src/processors/transforms/sort/cursor.rs @@ -14,6 +14,8 @@ use std::cmp::Ordering; +use databend_common_expression::Column; + use super::rows::Rows; /// A cursor point to a certain row in a data block. @@ -64,6 +66,11 @@ impl Cursor { pub fn num_rows(&self) -> usize { self.num_rows } + + #[inline] + pub fn to_column(&self) -> Column { + self.rows.to_column() + } } impl Ord for Cursor { diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/merger.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/merger.rs new file mode 100644 index 000000000000..a2ae12462731 --- /dev/null +++ b/src/query/pipeline/transforms/src/processors/transforms/sort/merger.rs @@ -0,0 +1,321 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::collections::VecDeque; +use std::sync::Arc; + +use databend_common_exception::Result; +use databend_common_expression::Column; +use databend_common_expression::DataBlock; +use databend_common_expression::DataSchemaRef; +use databend_common_expression::SortColumnDescription; + +use super::utils::find_bigger_child_of_root; +use super::Cursor; +use super::Rows; + +#[async_trait::async_trait] +pub trait SortedStream { + /// Returns the next block with the order column and if it is pending. + /// + /// If the block is [None] and it's not pending, it means the stream is finished. + /// If the block is [None] but it's pending, it means the stream is not finished yet. + fn next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> { + Ok((None, false)) + } + + /// The async version of `next`. + async fn async_next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> { + self.next() + } +} + +/// A merge sort operator to merge multiple sorted streams and output one sorted stream. 
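+///
+/// A minimal sketch of the intended call pattern, assuming a caller-provided
+/// `SortedStream` implementation (`MyStream` and `consume` here are hypothetical;
+/// see `tests/it/merger.rs` in this patch for a runnable version):
+///
+/// ```ignore
+/// let mut merger = HeapMerger::<SimpleRows<Int32Type>, MyStream>::create(
+///     schema, streams, sort_desc, batch_size, limit,
+/// );
+/// while !merger.is_finished() {
+///     // `next_block` returns `None` while any input stream is still pending.
+///     if let Some(block) = merger.next_block()? {
+///         consume(block); // at most `batch_size` sorted rows per block
+///     }
+/// }
+/// ```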
+pub struct HeapMerger +where + R: Rows, + S: SortedStream, +{ + schema: DataSchemaRef, + sort_desc: Arc>, + unsorted_streams: Vec, + heap: BinaryHeap>>, + buffer: Vec, + pending_streams: VecDeque, + batch_size: usize, + limit: Option, + + temp_sorted_num_rows: usize, + temp_output_indices: Vec<(usize, usize, usize)>, + temp_sorted_blocks: Vec, +} + +impl HeapMerger +where + R: Rows, + S: SortedStream + Send, +{ + pub fn create( + schema: DataSchemaRef, + streams: Vec, + sort_desc: Arc>, + batch_size: usize, + limit: Option, + ) -> Self { + // We only create a merger when there are at least two streams. + debug_assert!(streams.len() > 1, "streams.len() = {}", streams.len()); + + let heap = BinaryHeap::with_capacity(streams.len()); + let buffer = vec![DataBlock::empty_with_schema(schema.clone()); streams.len()]; + let pending_stream = (0..streams.len()).collect(); + + Self { + schema, + unsorted_streams: streams, + heap, + buffer, + batch_size, + limit, + sort_desc, + pending_streams: pending_stream, + temp_sorted_num_rows: 0, + temp_output_indices: vec![], + temp_sorted_blocks: vec![], + } + } + + #[inline(always)] + pub fn is_finished(&self) -> bool { + (self.heap.is_empty() && !self.has_pending_stream() && self.temp_sorted_num_rows == 0) + || self.limit == Some(0) + } + + #[inline(always)] + pub fn has_pending_stream(&self) -> bool { + !self.pending_streams.is_empty() + } + + // This method can only be called when there is no data of the stream in the heap. + pub async fn async_poll_pending_stream(&mut self) -> Result<()> { + let mut continue_pendings = Vec::new(); + while let Some(i) = self.pending_streams.pop_front() { + debug_assert!(self.buffer[i].is_empty()); + let (input, pending) = self.unsorted_streams[i].async_next().await?; + if pending { + continue_pendings.push(i); + continue; + } + if let Some((block, col)) = input { + let rows = R::from_column(&col, &self.sort_desc)?; + let cursor = Cursor::new(i, rows); + self.heap.push(Reverse(cursor)); + self.buffer[i] = block; + } + } + self.pending_streams.extend(continue_pendings); + Ok(()) + } + + #[inline] + pub fn poll_pending_stream(&mut self) -> Result<()> { + let mut continue_pendings = Vec::new(); + while let Some(i) = self.pending_streams.pop_front() { + debug_assert!(self.buffer[i].is_empty()); + let (input, pending) = self.unsorted_streams[i].next()?; + if pending { + continue_pendings.push(i); + continue; + } + if let Some((block, col)) = input { + let rows = R::from_column(&col, &self.sort_desc)?; + let cursor = Cursor::new(i, rows); + self.heap.push(Reverse(cursor)); + self.buffer[i] = block; + } + } + self.pending_streams.extend(continue_pendings); + Ok(()) + } + + /// To evaluate the current cursor, and update the top of the heap if necessary. + /// This method can only be called when iterating the heap. + /// + /// Return `true` if the batch is full (need to output). 
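+    ///
+    /// Three cases mirror the branches below: with a single remaining cursor the
+    /// block is drained directly; if the current block's last row is no greater
+    /// than the next cursor's current row, the whole block is taken in one slice
+    /// (the short path); otherwise rows are advanced one at a time until this
+    /// cursor passes the next one or the batch fills up.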
+ #[inline(always)] + fn evaluate_cursor(&mut self, mut cursor: Cursor) -> bool { + let max_rows = self.limit.unwrap_or(self.batch_size).min(self.batch_size); + if self.heap.len() == 1 { + let start = cursor.row_index; + let count = (cursor.num_rows() - start).min(max_rows - self.temp_sorted_num_rows); + self.temp_sorted_num_rows += count; + cursor.row_index += count; + self.temp_output_indices + .push((cursor.input_index, start, count)); + } else { + let next_cursor = &find_bigger_child_of_root(&self.heap).0; + if cursor.last().le(&next_cursor.current()) { + // Short Path: + // If the last row of current block is smaller than the next cursor, + // we can drain the whole block. + let start = cursor.row_index; + let count = (cursor.num_rows() - start).min(max_rows - self.temp_sorted_num_rows); + self.temp_sorted_num_rows += count; + cursor.row_index += count; + self.temp_output_indices + .push((cursor.input_index, start, count)); + } else { + // We copy current cursor for advancing, + // and we will use this copied cursor to update the top of the heap at last + // (let heap adjust itself without popping and pushing any element). + let start = cursor.row_index; + while !cursor.is_finished() + && cursor.le(next_cursor) + && self.temp_sorted_num_rows < max_rows + { + // If the cursor is smaller than the next cursor, don't need to push the cursor back to the heap. + self.temp_sorted_num_rows += 1; + cursor.advance(); + } + self.temp_output_indices.push(( + cursor.input_index, + start, + cursor.row_index - start, + )); + } + } + + if !cursor.is_finished() { + // Update the top of the heap. + // `self.heap.peek_mut` will return a `PeekMut` object which allows us to modify the top element of the heap. + // The heap will adjust itself automatically when the `PeekMut` object is dropped (RAII). + self.heap.peek_mut().unwrap().0 = cursor; + } else { + // Pop the current `cursor`. + self.heap.pop(); + // We have read all rows of this block, need to release the old memory and read a new one. + let temp_block = DataBlock::take_by_slices_limit_from_blocks( + &self.buffer, + &self.temp_output_indices, + None, + ); + self.buffer[cursor.input_index] = DataBlock::empty_with_schema(self.schema.clone()); + self.temp_sorted_blocks.push(temp_block); + self.temp_output_indices.clear(); + self.pending_streams.push_back(cursor.input_index); + } + + debug_assert!(self.temp_sorted_num_rows <= max_rows); + self.temp_sorted_num_rows == max_rows + } + + fn build_output(&mut self) -> Result { + if !self.temp_output_indices.is_empty() { + let block = DataBlock::take_by_slices_limit_from_blocks( + &self.buffer, + &self.temp_output_indices, + None, + ); + self.temp_sorted_blocks.push(block); + } + let block = DataBlock::concat(&self.temp_sorted_blocks)?; + + debug_assert_eq!(block.num_rows(), self.temp_sorted_num_rows); + debug_assert!(block.num_rows() <= self.batch_size); + + self.limit = self.limit.map(|limit| limit - self.temp_sorted_num_rows); + self.temp_sorted_blocks.clear(); + self.temp_output_indices.clear(); + self.temp_sorted_num_rows = 0; + + Ok(block) + } + + /// Returns the next sorted block and if it is pending. + /// + /// If the block is [None], it means the merger is finished or pending (has pending streams). + pub fn next_block(&mut self) -> Result> { + if self.is_finished() { + return Ok(None); + } + + if self.has_pending_stream() { + self.poll_pending_stream()?; + if self.has_pending_stream() { + return Ok(None); + } + } + + // No pending streams now. 
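+        // If the heap is drained but some rows are already staged in
+        // `temp_output_indices` / `temp_sorted_blocks`, flush them as the last block.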
+ if self.heap.is_empty() { + return if self.temp_sorted_num_rows > 0 { + Ok(Some(self.build_output()?)) + } else { + Ok(None) + }; + } + + while let Some(Reverse(cursor)) = self.heap.peek() { + if self.evaluate_cursor(cursor.clone()) { + break; + } + if self.has_pending_stream() { + self.poll_pending_stream()?; + if self.has_pending_stream() { + return Ok(None); + } + } + } + + Ok(Some(self.build_output()?)) + } + + /// The async version of `next_block`. + pub async fn async_next_block(&mut self) -> Result> { + if self.is_finished() { + return Ok(None); + } + + if self.has_pending_stream() { + self.async_poll_pending_stream().await?; + if self.has_pending_stream() { + return Ok(None); + } + } + + // No pending streams now. + if self.heap.is_empty() { + return if self.temp_sorted_num_rows > 0 { + Ok(Some(self.build_output()?)) + } else { + Ok(None) + }; + } + + while let Some(Reverse(cursor)) = self.heap.peek() { + if self.evaluate_cursor(cursor.clone()) { + break; + } + if self.has_pending_stream() { + self.async_poll_pending_stream().await?; + if self.has_pending_stream() { + return Ok(None); + } + } + } + + Ok(Some(self.build_output()?)) + } +} diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/mod.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/mod.rs index a417cba75819..edcf7f7c668a 100644 --- a/src/query/pipeline/transforms/src/processors/transforms/sort/mod.rs +++ b/src/query/pipeline/transforms/src/processors/transforms/sort/mod.rs @@ -13,10 +13,12 @@ // limitations under the License. mod cursor; +mod merger; mod rows; mod spill; pub mod utils; pub use cursor::*; +pub use merger::*; pub use rows::*; pub use spill::*; diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/common.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/common.rs index 6405fb0b30a6..fcfbe291073d 100644 --- a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/common.rs +++ b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/common.rs @@ -48,9 +48,13 @@ impl Rows for StringColumn { Column::String(self.clone()) } - fn from_column(col: Column, _: &[SortColumnDescription]) -> Option { + fn try_from_column(col: &Column, _: &[SortColumnDescription]) -> Option { col.as_string().cloned() } + + fn data_type() -> DataType { + DataType::String + } } impl RowConverter for CommonRowConverter { diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/mod.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/mod.rs index 9437f466c7e0..68892847513c 100644 --- a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/mod.rs +++ b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/mod.rs @@ -16,7 +16,9 @@ mod common; mod simple; pub use common::*; +use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_expression::types::DataType; use databend_common_expression::BlockEntry; use databend_common_expression::Column; use databend_common_expression::DataSchemaRef; @@ -44,7 +46,19 @@ where Self: Sized + Clone fn len(&self) -> usize; fn row(&self, index: usize) -> Self::Item<'_>; fn to_column(&self) -> Column; - fn from_column(col: Column, desc: &[SortColumnDescription]) -> Option; + + fn from_column(col: &Column, desc: &[SortColumnDescription]) -> Result { + Self::try_from_column(col, desc).ok_or_else(|| { + ErrorCode::BadDataValueType(format!( + "Order column type mismatched. 
Expected {} but got {}", Self::data_type(), col.data_type() )) }) } + fn try_from_column(col: &Column, desc: &[SortColumnDescription]) -> Option; + + fn data_type() -> DataType; fn is_empty(&self) -> bool { self.len() == 0 diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/simple.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/simple.rs index 88367d1f9e70..19314173bbfc 100644 --- a/src/query/pipeline/transforms/src/processors/transforms/sort/rows/simple.rs +++ b/src/query/pipeline/transforms/src/processors/transforms/sort/rows/simple.rs @@ -18,6 +18,7 @@ use std::marker::PhantomData; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::ArgType; +use databend_common_expression::types::DataType; use databend_common_expression::types::DateType; use databend_common_expression::types::StringType; use databend_common_expression::types::TimestampType; @@ -93,7 +94,7 @@ where impl Rows for SimpleRows where - T: ValueType, + T: ArgType, T::Scalar: Ord, { type Item<'a> = SimpleRow; @@ -114,13 +115,17 @@ where T::upcast_column(self.inner.clone()) } - fn from_column(col: Column, desc: &[SortColumnDescription]) -> Option { - let inner = T::try_downcast_column(&col)?; + fn try_from_column(col: &Column, desc: &[SortColumnDescription]) -> Option { + let inner = T::try_downcast_column(col)?; Some(Self { inner, desc: !desc[0].asc, }) } + + fn data_type() -> DataType { + T::data_type() + } } pub type DateConverter = SimpleRowConverter; diff --git a/src/query/pipeline/transforms/src/processors/transforms/sort/utils.rs b/src/query/pipeline/transforms/src/processors/transforms/sort/utils.rs index 56c314935f55..9e466e14138b 100644 --- a/src/query/pipeline/transforms/src/processors/transforms/sort/utils.rs +++ b/src/query/pipeline/transforms/src/processors/transforms/sort/utils.rs @@ -14,6 +14,15 @@ use std::collections::BinaryHeap; +use databend_common_expression::types::DataType; +use databend_common_expression::DataField; +use databend_common_expression::DataSchema; +use databend_common_expression::DataSchemaRef; +use databend_common_expression::DataSchemaRefExt; +use databend_common_expression::SortColumnDescription; + +pub const ORDER_COL_NAME: &str = "_order_col"; + /// Find the bigger child of the root of the heap.
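+/// `BinaryHeap` stores its elements as a level-ordered array, so the children
+/// of the root sit at indices 1 and 2 of `as_slice()`; with exactly two
+/// elements, index 1 holds the only child.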
#[inline(always)] pub fn find_bigger_child_of_root(heap: &BinaryHeap) -> &T { @@ -25,3 +34,32 @@ pub fn find_bigger_child_of_root(heap: &BinaryHeap) -> &T { (&slice[1]).max(&slice[2]) } } + +#[inline(always)] +fn order_field_type(schema: &DataSchema, desc: &[SortColumnDescription]) -> DataType { + debug_assert!(!desc.is_empty()); + if desc.len() == 1 { + let order_by_field = schema.field(desc[0].offset); + if matches!( + order_by_field.data_type(), + DataType::Number(_) | DataType::Date | DataType::Timestamp | DataType::String + ) { + return order_by_field.data_type().clone(); + } + } + DataType::String +} + +#[inline(always)] +pub fn add_order_field(schema: DataSchemaRef, desc: &[SortColumnDescription]) -> DataSchemaRef { + if let Some(f) = schema.fields.last() && f.name() == ORDER_COL_NAME { + schema + } else { + let mut fields = schema.fields().clone(); + fields.push(DataField::new( + ORDER_COL_NAME, + order_field_type(&schema, desc), + )); + DataSchemaRefExt::create(fields) + } +} diff --git a/src/query/pipeline/transforms/src/processors/transforms/transform_multi_sort_merge.rs b/src/query/pipeline/transforms/src/processors/transforms/transform_multi_sort_merge.rs index 5de1dd6daea7..69124479fbf0 100644 --- a/src/query/pipeline/transforms/src/processors/transforms/transform_multi_sort_merge.rs +++ b/src/query/pipeline/transforms/src/processors/transforms/transform_multi_sort_merge.rs @@ -512,17 +512,7 @@ where R: Rows + Send + 'static continue; } let mut block = block.convert_to_full(); - let order_col = block - .columns() - .last() - .unwrap() - .value - .as_column() - .unwrap() - .clone(); - let rows = R::from_column(order_col, &self.sort_desc).ok_or_else(|| { - ErrorCode::BadDataValueType("Order column type mismatched.") - })?; + let rows = R::from_column(block.get_last_column(), &self.sort_desc)?; // Remove the order column if self.remove_order_col { block.pop_columns(1); diff --git a/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge.rs b/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge.rs index 729b279ed117..cc2bee874477 100644 --- a/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge.rs +++ b/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge.rs @@ -12,9 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use std::cmp::Reverse; -use std::collections::BinaryHeap; use std::intrinsics::unlikely; +use std::marker::PhantomData; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::sync::Arc; @@ -24,16 +23,18 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::row::RowConverter as CommonConverter; use databend_common_expression::BlockMetaInfo; +use databend_common_expression::Column; use databend_common_expression::DataBlock; use databend_common_expression::DataSchemaRef; use databend_common_expression::SortColumnDescription; -use super::sort::utils::find_bigger_child_of_root; use super::sort::CommonRows; use super::sort::Cursor; use super::sort::DateConverter; use super::sort::DateRows; +use super::sort::HeapMerger; use super::sort::Rows; +use super::sort::SortedStream; use super::sort::StringConverter; use super::sort::StringRows; use super::sort::TimestampConverter; @@ -52,9 +53,11 @@ const SPILL_BATCH_BYTES_SIZE: usize = 8 * 1024 * 1024; /// /// For merge sort with limit, see [`super::transform_sort_merge_limit`] pub struct TransformSortMerge { + schema: DataSchemaRef, + sort_desc: Arc>, + block_size: usize, - heap: BinaryHeap>>, - buffer: Vec, + buffer: Vec>, aborting: Arc, // The following fields are used for spilling. @@ -71,18 +74,23 @@ pub struct TransformSortMerge { spill_batch_size: usize, /// The number of spilled blocks in each merge of the spill processor. spill_num_merge: usize, + + _r: PhantomData, } impl TransformSortMerge { pub fn create( + schema: DataSchemaRef, + sort_desc: Arc>, block_size: usize, max_memory_usage: usize, spilling_bytes_threshold: usize, ) -> Self { let may_spill = max_memory_usage != 0 && spilling_bytes_threshold != 0; TransformSortMerge { + schema, + sort_desc, block_size, - heap: BinaryHeap::new(), buffer: vec![], aborting: Arc::new(AtomicBool::new(false)), may_spill, @@ -92,6 +100,7 @@ impl TransformSortMerge { num_rows: 0, spill_batch_size: 0, spill_num_merge: 0, + _r: PhantomData, } } } @@ -112,8 +121,7 @@ impl MergeSort for TransformSortMerge { self.num_bytes += block.memory_size(); self.num_rows += block.num_rows(); - self.buffer.push(block); - self.heap.push(Reverse(init_cursor)); + self.buffer.push(Some((block, init_cursor.to_column()))); if self.may_spill && (self.num_bytes >= self.spilling_bytes_threshold @@ -130,9 +138,9 @@ impl MergeSort for TransformSortMerge { if self.spill_num_merge > 0 { debug_assert!(self.spill_batch_size > 0); // Make the last block as a big memory block. - self.drain_heap(usize::MAX) + self.merge_sort(usize::MAX) } else { - self.drain_heap(self.block_size) + self.merge_sort(self.block_size) } } @@ -159,7 +167,7 @@ impl TransformSortMerge { debug_assert!(self.spill_num_merge > 0); } - let mut blocks = self.drain_heap(self.spill_batch_size)?; + let mut blocks = self.merge_sort(self.spill_batch_size)?; if let Some(b) = blocks.first_mut() { b.replace_meta(spill_meta); } @@ -167,7 +175,6 @@ impl TransformSortMerge { b.replace_meta(Box::new(SortSpillMeta {})); } - debug_assert!(self.heap.is_empty()); self.num_rows = 0; self.num_bytes = 0; self.buffer.clear(); @@ -175,96 +182,60 @@ impl TransformSortMerge { Ok(blocks) } - fn drain_heap(&mut self, batch_size: usize) -> Result> { - // TODO: the codes is highly duplicated with the codes in `transform_sort_spill.rs`, - // need to refactor and merge them later. 
- if self.num_rows == 0 { + fn merge_sort(&mut self, batch_size: usize) -> Result> { + if self.buffer.is_empty() { return Ok(vec![]); } - let output_block_num = self.num_rows.div_ceil(batch_size); - let mut output_blocks = Vec::with_capacity(output_block_num); - let mut output_indices = Vec::with_capacity(output_block_num); - - // 1. Drain the heap - let mut temp_num_rows = 0; - let mut temp_indices = Vec::new(); - while let Some(Reverse(cursor)) = self.heap.peek() { - if unlikely(self.aborting.load(Ordering::Relaxed)) { - return Err(ErrorCode::AbortedQuery( - "Aborted query, because the server is shutting down or the query was killed.", - )); - } + let size_hint = self.num_rows.div_ceil(batch_size); - let mut cursor = cursor.clone(); - if self.heap.len() == 1 { - let start = cursor.row_index; - let count = (cursor.num_rows() - start).min(batch_size - temp_num_rows); - temp_num_rows += count; - cursor.row_index += count; - temp_indices.push((cursor.input_index, start, count)); - } else { - let next_cursor = &find_bigger_child_of_root(&self.heap).0; - if cursor.last().le(&next_cursor.current()) { - // Short Path: - // If the last row of current block is smaller than the next cursor, - // we can drain the whole block. - let start = cursor.row_index; - let count = (cursor.num_rows() - start).min(batch_size - temp_num_rows); - temp_num_rows += count; - cursor.row_index += count; - temp_indices.push((cursor.input_index, start, count)); - } else { - // We copy current cursor for advancing, - // and we will use this copied cursor to update the top of the heap at last - // (let heap adjust itself without popping and pushing any element). - let start = cursor.row_index; - while !cursor.is_finished() - && cursor.le(next_cursor) - && temp_num_rows < batch_size - { - // If the cursor is smaller than the next cursor, don't need to push the cursor back to the heap. - temp_num_rows += 1; - cursor.advance(); - } - temp_indices.push((cursor.input_index, start, cursor.row_index - start)); - } + if self.buffer.len() == 1 { + // If there is only one block, we don't need to merge. + let (block, _) = self.buffer.pop().unwrap().unwrap(); + let num_rows = block.num_rows(); + if size_hint == 1 { + return Ok(vec![block]); } - - if !cursor.is_finished() { - // Update the top of the heap. - // `self.heap.peek_mut` will return a `PeekMut` object which allows us to modify the top element of the heap. - // The heap will adjust itself automatically when the `PeekMut` object is dropped (RAII). - self.heap.peek_mut().unwrap().0 = cursor; - } else { - // Pop the current `cursor`. - self.heap.pop(); + let mut result = Vec::with_capacity(size_hint); + for i in 0..size_hint { + let start = i * batch_size; + let end = ((i + 1) * batch_size).min(num_rows); + let block = block.slice(start..end); + result.push(block); } - - if temp_num_rows == batch_size { - output_indices.push(temp_indices.clone()); - temp_indices.clear(); - temp_num_rows = 0; - } - } - - if !temp_indices.is_empty() { - output_indices.push(temp_indices); + return Ok(result); } - // 2. Build final blocks from `output_indices`. - for indices in output_indices { + let streams = self.buffer.drain(..).collect::>(); + let mut result = Vec::with_capacity(size_hint); + let mut merger = HeapMerger::::create( + self.schema.clone(), + streams, + self.sort_desc.clone(), + batch_size, + None, + ); + + while let Some(block) = merger.next_block()? 
{ if unlikely(self.aborting.load(Ordering::Relaxed)) { return Err(ErrorCode::AbortedQuery( "Aborted query, because the server is shutting down or the query was killed.", )); } - - let block = DataBlock::take_by_slices_limit_from_blocks(&self.buffer, &indices, None); - output_blocks.push(block); + result.push(block); } - Ok(output_blocks) + debug_assert!(merger.is_finished()); + + Ok(result) + } +} + +type BlockStream = Option<(DataBlock, Column)>; + +impl SortedStream for BlockStream { + fn next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> { + Ok((self.take(), false)) } } @@ -284,17 +255,18 @@ pub(super) type MergeSortCommon = TransformSortMergeBase; pub fn sort_merge( - data_schema: DataSchemaRef, + schema: DataSchemaRef, block_size: usize, sort_desc: Vec, data_blocks: Vec, ) -> Result> { + let sort_desc = Arc::new(sort_desc); let mut processor = MergeSortCommon::try_create( - data_schema, - Arc::new(sort_desc), + schema.clone(), + sort_desc.clone(), false, false, - MergeSortCommonImpl::create(block_size, 0, 0), + MergeSortCommonImpl::create(schema, sort_desc, block_size, 0, 0), )?; for block in data_blocks { processor.transform(block)?; diff --git a/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge_base.rs b/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge_base.rs index b16b7e1eeac3..3a636a284080 100644 --- a/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge_base.rs +++ b/src/query/pipeline/transforms/src/processors/transforms/transform_sort_merge_base.rs @@ -15,7 +15,6 @@ use std::marker::PhantomData; use std::sync::Arc; -use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; @@ -55,6 +54,7 @@ use super::MergeSortTimestamp; use super::MergeSortTimestampImpl; use super::TransformSortMerge; use super::TransformSortMergeLimit; +use crate::processors::sort::utils::ORDER_COL_NAME; pub enum Status { /// Continue to add blocks. @@ -134,16 +134,7 @@ where fn transform(&mut self, mut block: DataBlock) -> Result> { let rows = if self.order_col_generated { - let order_col = block - .columns() - .last() - .unwrap() - .value - .as_column() - .unwrap() - .clone(); - let rows = R::from_column(order_col, &self.sort_desc) - .ok_or_else(|| ErrorCode::BadDataValueType("Order column type mismatched."))?; + let rows = R::from_column(block.get_last_column(), &self.sort_desc)?; if !self.output_order_col { // The next processor could be a sort spill processor which need order column. // And the order column will be removed in that processor. 
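A minimal sketch of driving the rewritten `sort_merge` helper above, assuming each input block is already sorted on the key (the schema, column name, and block size are illustrative, not from the patch; it uses the same imports as `tests/it/merger.rs`):

```rust
// Hypothetical caller of sort_merge().
let schema = DataSchemaRefExt::create(vec![DataField::new(
    "a",
    DataType::Number(NumberDataType::Int32),
)]);
let sort_desc = vec![SortColumnDescription {
    offset: 0,
    asc: true,
    nulls_first: false,
    is_nullable: false,
}];
let blocks = vec![
    DataBlock::new_from_columns(vec![Int32Type::from_data(vec![1, 3, 5])]),
    DataBlock::new_from_columns(vec![Int32Type::from_data(vec![2, 4, 6])]),
];
// Merges the pre-sorted blocks into sorted output blocks of at most
// `block_size` rows each: [1, 2, 3, 4, 5, 6].
let sorted = sort_merge(schema, 65536, sort_desc, blocks)?;
```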
@@ -250,6 +241,12 @@ impl TransformSortMergeBuilder { } pub fn build(self) -> Result> { + debug_assert!(if self.output_order_col { + self.schema.has_field(ORDER_COL_NAME) + } else { + !self.schema.has_field(ORDER_COL_NAME) + }); + if self.limit.is_some() { self.build_sort_merge_limit() } else { @@ -283,14 +280,16 @@ impl TransformSortMergeBuilder { SimpleRows>, SimpleRowConverter>, >::try_create( - schema, - sort_desc, + schema.clone(), + sort_desc.clone(), order_col_generated, output_order_col, TransformSortMerge::create( + schema, + sort_desc, block_size, max_memory_usage, - spilling_bytes_threshold_per_core + spilling_bytes_threshold_per_core, ), )?, ), @@ -299,11 +298,13 @@ impl TransformSortMergeBuilder { input, output, MergeSortDate::try_create( - schema, - sort_desc, + schema.clone(), + sort_desc.clone(), order_col_generated, output_order_col, MergeSortDateImpl::create( + schema, + sort_desc, block_size, max_memory_usage, spilling_bytes_threshold_per_core, @@ -314,11 +315,13 @@ impl TransformSortMergeBuilder { input, output, MergeSortTimestamp::try_create( - schema, - sort_desc, + schema.clone(), + sort_desc.clone(), order_col_generated, output_order_col, MergeSortTimestampImpl::create( + schema, + sort_desc, block_size, max_memory_usage, spilling_bytes_threshold_per_core, @@ -329,11 +332,13 @@ impl TransformSortMergeBuilder { input, output, MergeSortString::try_create( - schema, - sort_desc, + schema.clone(), + sort_desc.clone(), order_col_generated, output_order_col, MergeSortStringImpl::create( + schema, + sort_desc, block_size, max_memory_usage, spilling_bytes_threshold_per_core, @@ -344,11 +349,13 @@ impl TransformSortMergeBuilder { input, output, MergeSortCommon::try_create( - schema, - sort_desc, + schema.clone(), + sort_desc.clone(), order_col_generated, output_order_col, MergeSortCommonImpl::create( + schema, + sort_desc, block_size, max_memory_usage, spilling_bytes_threshold_per_core, @@ -361,11 +368,13 @@ impl TransformSortMergeBuilder { input, output, MergeSortCommon::try_create( - schema, - sort_desc, + schema.clone(), + sort_desc.clone(), order_col_generated, output_order_col, MergeSortCommonImpl::create( + schema, + sort_desc, block_size, max_memory_usage, spilling_bytes_threshold_per_core, diff --git a/src/query/pipeline/transforms/tests/it/main.rs b/src/query/pipeline/transforms/tests/it/main.rs new file mode 100644 index 000000000000..dc8da8592a46 --- /dev/null +++ b/src/query/pipeline/transforms/tests/it/main.rs @@ -0,0 +1,15 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +mod merger; diff --git a/src/query/pipeline/transforms/tests/it/merger.rs b/src/query/pipeline/transforms/tests/it/merger.rs new file mode 100644 index 000000000000..18c9a81453db --- /dev/null +++ b/src/query/pipeline/transforms/tests/it/merger.rs @@ -0,0 +1,262 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; +use std::sync::Arc; + +use databend_common_base::base::tokio; +use databend_common_exception::Result; +use databend_common_expression::block_debug::pretty_format_blocks; +use databend_common_expression::types::DataType; +use databend_common_expression::types::Int32Type; +use databend_common_expression::types::NumberDataType; +use databend_common_expression::Column; +use databend_common_expression::DataBlock; +use databend_common_expression::DataField; +use databend_common_expression::DataSchemaRefExt; +use databend_common_expression::FromData; +use databend_common_expression::SortColumnDescription; +use databend_common_pipeline_transforms::processors::sort::HeapMerger; +use databend_common_pipeline_transforms::processors::sort::SimpleRows; +use databend_common_pipeline_transforms::processors::sort::SortedStream; +use itertools::Itertools; +use rand::rngs::ThreadRng; +use rand::Rng; + +struct TestStream { + data: VecDeque, + rng: ThreadRng, +} + +unsafe impl Send for TestStream {} + +impl TestStream { + fn new(data: VecDeque) -> Self { + Self { + data, + rng: rand::thread_rng(), + } + } +} + +impl SortedStream for TestStream { + fn next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> { + // To simulate the real scenario, we randomly decide whether the stream is pending or not. + let pending = self.rng.gen_bool(0.5); + if pending { + Ok((None, true)) + } else { + Ok(( + self.data.pop_front().map(|b| { + let col = b.get_last_column().clone(); + (b, col) + }), + false, + )) + } + } +} + +type TestMerger = HeapMerger, TestStream>; + +fn prepare_input_and_result( + data: Vec>>, + limit: Option, +) -> (Vec>, DataBlock) { + let input = data + .clone() + .into_iter() + .map(|v| { + v.into_iter() + .map(|v| DataBlock::new_from_columns(vec![Int32Type::from_data(v)])) + .collect::>() + }) + .collect::>(); + let result = data + .into_iter() + .flatten() + .flatten() + .sorted() + .take(limit.unwrap_or(usize::MAX)) + .collect::>(); + let result = DataBlock::new_from_columns(vec![Int32Type::from_data(result)]); + + (input, result) +} + +/// Returns (input, expected) +fn basic_test_data(limit: Option) -> (Vec>, DataBlock) { + let data = vec![ + vec![vec![1, 2, 3, 4], vec![4, 5, 6, 7]], + vec![vec![1, 1, 1, 1], vec![1, 10, 100, 2000]], + vec![vec![0, 2, 4, 5]], + ]; + + prepare_input_and_result(data, limit) +} + +/// Returns (input, expected, limit) +fn random_test_data(rng: &mut ThreadRng) -> (Vec>, DataBlock, Option) { + let random_batch_size = rng.gen_range(1..=10); + let random_num_streams = rng.gen_range(5..=10); + + let random_data = (0..random_num_streams) + .map(|_| { + let random_num_blocks = rng.gen_range(1..=10); + let mut data = (0..random_batch_size * random_num_blocks) + .map(|_| rng.gen_range(0..=1000)) + .collect::>(); + data.sort(); + data.chunks(random_batch_size) + .map(|v| v.to_vec()) + .collect::>() + }) + .collect::>(); + + let num_rows = random_data + .iter() + .map(|v| v.iter().map(|v| v.len()).sum::()) + .sum::(); + let limit = rng.gen_range(0..=num_rows); + let (input, expected) =
prepare_input_and_result(random_data, Some(limit)); + (input, expected, Some(limit)) +} + +fn create_test_merger(input: Vec>, limit: Option) -> TestMerger { + let schema = DataSchemaRefExt::create(vec![DataField::new( + "a", + DataType::Number(NumberDataType::Int32), + )]); + let sort_desc = Arc::new(vec![SortColumnDescription { + offset: 0, + asc: true, + nulls_first: true, + is_nullable: false, + }]); + let streams = input + .into_iter() + .map(|v| TestStream::new(v.into_iter().collect::>())) + .collect::>(); + + TestMerger::create(schema, streams, sort_desc, 4, limit) +} + +fn check_result(result: Vec, expected: DataBlock) { + if expected.is_empty() { + if !result.is_empty() && !DataBlock::concat(&result).unwrap().is_empty() { + panic!( + "\nexpected empty block, but got:\n {}", + pretty_format_blocks(&result).unwrap() + ) + } + return; + } + + let result_rows = result.iter().map(|v| v.num_rows()).sum::(); + let result = pretty_format_blocks(&result).unwrap(); + let expected_rows = expected.num_rows(); + let expected = pretty_format_blocks(&[expected]).unwrap(); + assert_eq!( + expected, result, + "\nexpected (num_rows = {}):\n{}\nactual (num_rows = {}):\n{}", + expected_rows, expected, result_rows, result + ); +} + +fn test(mut merger: TestMerger, expected: DataBlock) -> Result<()> { + let mut result = Vec::new(); + + while !merger.is_finished() { + if let Some(block) = merger.next_block()? { + result.push(block); + } + } + + check_result(result, expected); + + Ok(()) +} + +async fn async_test(mut merger: TestMerger, expected: DataBlock) -> Result<()> { + let mut result = Vec::new(); + + while !merger.is_finished() { + if let Some(block) = merger.async_next_block().await? { + result.push(block); + } + } + check_result(result, expected); + + Ok(()) +} + +fn test_basic(limit: Option) -> Result<()> { + let (input, expected) = basic_test_data(limit); + let merger = create_test_merger(input, limit); + test(merger, expected) +} + +async fn async_test_basic(limit: Option) -> Result<()> { + let (input, expected) = basic_test_data(limit); + let merger = create_test_merger(input, limit); + async_test(merger, expected).await +} + +#[test] +fn test_basic_with_limit() -> Result<()> { + test_basic(None)?; + test_basic(Some(0))?; + test_basic(Some(1))?; + test_basic(Some(5))?; + test_basic(Some(20))?; + test_basic(Some(21))?; + test_basic(Some(1000000)) +} + +#[tokio::test(flavor = "multi_thread")] +async fn async_test_basic_with_limit() -> Result<()> { + async_test_basic(None).await?; + async_test_basic(Some(0)).await?; + async_test_basic(Some(1)).await?; + async_test_basic(Some(5)).await?; + async_test_basic(Some(20)).await?; + async_test_basic(Some(21)).await?; + async_test_basic(Some(1000000)).await +} + +#[test] +fn test_fuzz() -> Result<()> { + let mut rng = rand::thread_rng(); + + for _ in 0..10 { + let (input, expected, limit) = random_test_data(&mut rng); + let merger = create_test_merger(input, limit); + test(merger, expected)?; + } + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread")] +async fn test_fuzz_async() -> Result<()> { + let mut rng = rand::thread_rng(); + + for _ in 0..10 { + let (input, expected, limit) = random_test_data(&mut rng); + let merger = create_test_merger(input, limit); + async_test(merger, expected).await?; + } + + Ok(()) +} diff --git a/src/query/service/src/pipelines/builders/builder_sort.rs b/src/query/service/src/pipelines/builders/builder_sort.rs index b921b8197695..47f7a087ae75 100644 --- a/src/query/service/src/pipelines/builders/builder_sort.rs +++ 
b/src/query/service/src/pipelines/builders/builder_sort.rs @@ -15,15 +15,12 @@ use std::sync::Arc; use databend_common_exception::Result; -use databend_common_expression::types::DataType; -use databend_common_expression::DataField; -use databend_common_expression::DataSchema; use databend_common_expression::DataSchemaRef; -use databend_common_expression::DataSchemaRefExt; use databend_common_expression::SortColumnDescription; use databend_common_pipeline_core::processors::ProcessorPtr; use databend_common_pipeline_core::query_spill_prefix; use databend_common_pipeline_core::Pipeline; +use databend_common_pipeline_transforms::processors::sort::utils::add_order_field; use databend_common_pipeline_transforms::processors::try_add_multi_sort_merge; use databend_common_pipeline_transforms::processors::ProcessorProfileWrapper; use databend_common_pipeline_transforms::processors::TransformSortMergeBuilder; @@ -76,24 +73,24 @@ impl PipelineBuilder { } } - let input_schema = sort.output_schema()?; + let plan_schema = sort.output_schema()?; let sort_desc = sort .order_by .iter() .map(|desc| { - let offset = input_schema.index_of(&desc.order_by.to_string())?; + let offset = plan_schema.index_of(&desc.order_by.to_string())?; Ok(SortColumnDescription { offset, asc: desc.asc, nulls_first: desc.nulls_first, - is_nullable: input_schema.field(offset).is_nullable(), // This information is not needed here. + is_nullable: plan_schema.field(offset).is_nullable(), // This information is not needed here. }) }) .collect::>>()?; self.build_sort_pipeline( - input_schema, + plan_schema, sort_desc, sort.plan_id, sort.limit, @@ -103,7 +100,7 @@ impl PipelineBuilder { pub(crate) fn build_sort_pipeline( &mut self, - input_schema: DataSchemaRef, + plan_schema: DataSchemaRef, sort_desc: Vec, plan_id: u32, limit: Option, @@ -124,7 +121,7 @@ impl PipelineBuilder { }; let mut builder = - SortPipelineBuilder::create(self.ctx.clone(), input_schema.clone(), sort_desc.clone()) + SortPipelineBuilder::create(self.ctx.clone(), plan_schema.clone(), sort_desc.clone()) .with_partial_block_size(block_size) .with_final_block_size(block_size) .with_limit(limit) @@ -139,7 +136,7 @@ impl PipelineBuilder { if self.main_pipeline.output_len() > 1 { try_add_multi_sort_merge( &mut self.main_pipeline, - input_schema, + plan_schema, block_size, limit, sort_desc, @@ -293,11 +290,17 @@ impl SortPipelineBuilder { let may_spill = max_memory_usage != 0 && bytes_limit_per_proc != 0; + let sort_merge_output_schema = if output_order_col || may_spill { + add_order_field(self.schema.clone(), &self.sort_desc) + } else { + self.schema.clone() + }; + pipeline.add_transform(|input, output| { let builder = TransformSortMergeBuilder::create( input, output, - self.schema.clone(), + sort_merge_output_schema.clone(), self.sort_desc.clone(), self.partial_block_size, ) @@ -320,18 +323,8 @@ impl SortPipelineBuilder { })?; if may_spill { + let schema = add_order_field(sort_merge_output_schema.clone(), &self.sort_desc); let config = SpillerConfig::create(query_spill_prefix(&self.ctx.get_tenant())); - // The input of the processor must contain an order column. 
- let schema = if let Some(f) = self.schema.fields.last() && f.name() == "_order_col" { - self.schema.clone() - } else { - let mut fields = self.schema.fields().clone(); - fields.push(DataField::new( - "_order_col", - order_column_type(&self.sort_desc, &self.schema), - )); - DataSchemaRefExt::create(fields) - }; pipeline.add_transform(|input, output| { let op = DataOperator::instance().operator(); let spiller = @@ -360,7 +353,7 @@ impl SortPipelineBuilder { // Multi-pipelines merge sort try_add_multi_sort_merge( pipeline, - self.schema, + self.schema.clone(), self.final_block_size, self.limit, self.sort_desc, @@ -372,17 +365,3 @@ impl SortPipelineBuilder { Ok(()) } } - -fn order_column_type(desc: &[SortColumnDescription], schema: &DataSchema) -> DataType { - debug_assert!(!desc.is_empty()); - if desc.len() == 1 { - let order_by_field = schema.field(desc[0].offset); - if matches!( - order_by_field.data_type(), - DataType::Number(_) | DataType::Date | DataType::Timestamp | DataType::String - ) { - return order_by_field.data_type().clone(); - } - } - DataType::String -} diff --git a/src/query/service/src/pipelines/processors/transforms/transform_sort_spill.rs b/src/query/service/src/pipelines/processors/transforms/transform_sort_spill.rs index 4b1f6b42d92c..1f71989ee29b 100644 --- a/src/query/service/src/pipelines/processors/transforms/transform_sort_spill.rs +++ b/src/query/service/src/pipelines/processors/transforms/transform_sort_spill.rs @@ -13,20 +13,18 @@ // limitations under the License. use std::any::Any; -use std::cmp::Reverse; -use std::collections::BinaryHeap; use std::collections::VecDeque; use std::marker::PhantomData; use std::sync::Arc; use std::time::Instant; -use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::NumberType; use databend_common_expression::with_number_mapped_type; use databend_common_expression::BlockMetaInfoDowncast; +use databend_common_expression::Column; use databend_common_expression::DataBlock; use databend_common_expression::DataSchemaRef; use databend_common_expression::SortColumnDescription; @@ -40,14 +38,14 @@ use databend_common_pipeline_core::processors::Event; use databend_common_pipeline_core::processors::InputPort; use databend_common_pipeline_core::processors::OutputPort; use databend_common_pipeline_core::processors::Processor; -use databend_common_pipeline_transforms::processors::sort::utils::find_bigger_child_of_root; use databend_common_pipeline_transforms::processors::sort::CommonRows; -use databend_common_pipeline_transforms::processors::sort::Cursor; use databend_common_pipeline_transforms::processors::sort::DateRows; +use databend_common_pipeline_transforms::processors::sort::HeapMerger; use databend_common_pipeline_transforms::processors::sort::Rows; use databend_common_pipeline_transforms::processors::sort::SimpleRows; use databend_common_pipeline_transforms::processors::sort::SortSpillMeta; use databend_common_pipeline_transforms::processors::sort::SortSpillMetaWithParams; +use databend_common_pipeline_transforms::processors::sort::SortedStream; use databend_common_pipeline_transforms::processors::sort::StringRows; use databend_common_pipeline_transforms::processors::sort::TimestampRows; @@ -326,15 +324,16 @@ where R: Rows + Sync + Send + 'static streams.push(stream); } - let mut merger = Merger::::create( + let mut merger = HeapMerger::::create( self.schema.clone(), 
streams, self.sort_desc.clone(), self.batch_size, + None, ); let mut spilled = VecDeque::new(); - while let Some(block) = merger.next().await? { + while let Some(block) = merger.async_next_block().await? { let ins = Instant::now(); let (location, bytes) = self.spiller.spill_block(block).await?; @@ -347,6 +346,8 @@ where R: Rows + Sync + Send + 'static spilled.push_back(location); } + + debug_assert!(merger.is_finished()); self.unmerged_blocks.push_back(spilled); Ok(()) @@ -358,8 +359,9 @@ enum BlockStream { Block(Option), } -impl BlockStream { - async fn next(&mut self) -> Result> { +#[async_trait::async_trait] +impl SortedStream for BlockStream { + async fn async_next(&mut self) -> Result<(Option<(DataBlock, Column)>, bool)> { let block = match self { BlockStream::Block(block) => block.take(), BlockStream::Spilled((files, spiller)) => { @@ -380,149 +382,13 @@ impl BlockStream { } } }; - Ok(block) - } -} - -/// A merge sort operator to merge multiple sorted streams. -/// -/// TODO: reuse this operator in other places such as `TransformMultiSortMerge` and `TransformSortMerge`. -struct Merger { - schema: DataSchemaRef, - sort_desc: Arc>, - unsorted_streams: Vec, - heap: BinaryHeap>>, - buffer: Vec, - pending_stream: VecDeque, - batch_size: usize, -} - -impl Merger { - fn create( - schema: DataSchemaRef, - streams: Vec, - sort_desc: Arc>, - batch_size: usize, - ) -> Self { - // We only create a merger when there are at least two streams. - debug_assert!(streams.len() > 1); - let heap = BinaryHeap::with_capacity(streams.len()); - let buffer = vec![DataBlock::empty_with_schema(schema.clone()); streams.len()]; - let pending_stream = (0..streams.len()).collect(); - - Self { - schema, - unsorted_streams: streams, - heap, - buffer, - batch_size, - sort_desc, - pending_stream, - } - } - - // This method can only be called when there is no data of the stream in the heap. - async fn poll_pending_stream(&mut self) -> Result<()> { - while let Some(i) = self.pending_stream.pop_front() { - debug_assert!(self.buffer[i].is_empty()); - if let Some(block) = self.unsorted_streams[i].next().await? { - let order_col = block.columns().last().unwrap().value.as_column().unwrap(); - let rows = R::from_column(order_col.clone(), &self.sort_desc) - .ok_or_else(|| ErrorCode::BadDataValueType("Order column type mismatched."))?; - let cursor = Cursor::new(i, rows); - self.heap.push(Reverse(cursor)); - self.buffer[i] = block; - } - } - Ok(()) - } - - async fn next(&mut self) -> Result> { - if !self.pending_stream.is_empty() { - self.poll_pending_stream().await?; - } - - if self.heap.is_empty() { - return Ok(None); - } - - let mut num_rows = 0; - - // (input_index, row_start, count) - let mut output_indices = Vec::new(); - let mut temp_sorted_blocks = Vec::new(); - - while let Some(Reverse(cursor)) = self.heap.peek() { - let mut cursor = cursor.clone(); - if self.heap.len() == 1 { - let start = cursor.row_index; - let count = (cursor.num_rows() - start).min(self.batch_size - num_rows); - num_rows += count; - cursor.row_index += count; - output_indices.push((cursor.input_index, start, count)); - } else { - let next_cursor = &find_bigger_child_of_root(&self.heap).0; - if cursor.last().le(&next_cursor.current()) { - // Short Path: - // If the last row of current block is smaller than the next cursor, - // we can drain the whole block. 
- let start = cursor.row_index; - let count = (cursor.num_rows() - start).min(self.batch_size - num_rows); - num_rows += count; - cursor.row_index += count; - output_indices.push((cursor.input_index, start, count)); - } else { - // We copy current cursor for advancing, - // and we will use this copied cursor to update the top of the heap at last - // (let heap adjust itself without popping and pushing any element). - let start = cursor.row_index; - while !cursor.is_finished() - && cursor.le(next_cursor) - && num_rows < self.batch_size - { - // If the cursor is smaller than the next cursor, don't need to push the cursor back to the heap. - num_rows += 1; - cursor.advance(); - } - output_indices.push((cursor.input_index, start, cursor.row_index - start)); - } - } - - if !cursor.is_finished() { - // Update the top of the heap. - // `self.heap.peek_mut` will return a `PeekMut` object which allows us to modify the top element of the heap. - // The heap will adjust itself automatically when the `PeekMut` object is dropped (RAII). - self.heap.peek_mut().unwrap().0 = cursor; - } else { - // Pop the current `cursor`. - self.heap.pop(); - // We have read all rows of this block, need to release the old memory and read a new one. - let temp_block = DataBlock::take_by_slices_limit_from_blocks( - &self.buffer, - &output_indices, - None, - ); - self.buffer[cursor.input_index] = DataBlock::empty_with_schema(self.schema.clone()); - temp_sorted_blocks.push(temp_block); - output_indices.clear(); - self.pending_stream.push_back(cursor.input_index); - self.poll_pending_stream().await?; - } - - if num_rows == self.batch_size { - break; - } - } - - if !output_indices.is_empty() { - let block = - DataBlock::take_by_slices_limit_from_blocks(&self.buffer, &output_indices, None); - temp_sorted_blocks.push(block); - } - - let block = DataBlock::concat(&temp_sorted_blocks)?; - debug_assert!(block.num_rows() <= self.batch_size); - Ok(Some(block)) + Ok(( + block.map(|b| { + let col = b.get_last_column().clone(); + (b, col) + }), + false, + )) } } @@ -612,6 +478,7 @@ mod tests { use databend_common_pipeline_core::processors::InputPort; use databend_common_pipeline_core::processors::OutputPort; use databend_common_pipeline_transforms::processors::sort::SimpleRows; + use databend_common_pipeline_transforms::processors::sort::SortedStream; use databend_common_storage::DataOperator; use itertools::Itertools; use rand::rngs::ThreadRng; @@ -737,7 +604,8 @@ mod tests { )); let mut result = Vec::new(); - while let Some(block) = block_stream.next().await? { + + while let (Some((block, _)), _) = block_stream.async_next().await? { result.push(block); } diff --git a/src/query/sql/src/executor/physical_plans/physical_sort.rs b/src/query/sql/src/executor/physical_plans/physical_sort.rs index 67c940b8af88..0a9714ef869f 100644 --- a/src/query/sql/src/executor/physical_plans/physical_sort.rs +++ b/src/query/sql/src/executor/physical_plans/physical_sort.rs @@ -18,6 +18,7 @@ use databend_common_expression::DataField; use databend_common_expression::DataSchema; use databend_common_expression::DataSchemaRef; use databend_common_expression::DataSchemaRefExt; +use databend_common_pipeline_transforms::processors::sort::utils::ORDER_COL_NAME; use itertools::Itertools; use crate::executor::explain::PlanStatsInfo; @@ -65,7 +66,7 @@ impl Sort { if matches!(self.after_exchange, Some(true)) { // If the plan is after exchange plan in cluster mode, // the order column is at the last of the input schema. 
- debug_assert_eq!(fields.last().unwrap().name(), "_order_col"); + debug_assert_eq!(fields.last().unwrap().name(), ORDER_COL_NAME); debug_assert_eq!( fields.last().unwrap().data_type(), &self.order_col_type(&input_schema)? @@ -88,7 +89,7 @@ impl Sort { // If the plan is before exchange plan in cluster mode, // the order column should be added to the output schema. fields.push(DataField::new( - "_order_col", + ORDER_COL_NAME, self.order_col_type(&input_schema)?, )); }
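
As a closing note on the schema convention: `add_order_field` in `sort/utils.rs` is idempotent, which is why `builder_sort.rs` can call it on both the plan schema and the sort-merge output schema. A small sketch under the same illustrative single-key schema as above:

```rust
// Single Int32 sort key, so order_field_type() reuses the key's type
// instead of falling back to DataType::String.
let schema = DataSchemaRefExt::create(vec![DataField::new(
    "a",
    DataType::Number(NumberDataType::Int32),
)]);
let desc = vec![SortColumnDescription {
    offset: 0,
    asc: true,
    nulls_first: false,
    is_nullable: false,
}];
let with_order = add_order_field(schema, &desc);
assert_eq!(with_order.fields().last().unwrap().name(), ORDER_COL_NAME);

// A second call detects the trailing "_order_col" and returns the schema as is.
let again = add_order_field(with_order.clone(), &desc);
assert_eq!(again.fields().len(), with_order.fields().len());
```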