Skip to content

Commit

Permalink
feat: Merge sort Logical plan (#4768)
Browse files Browse the repository at this point in the history
* feat(WIP): MergeSort

* wip

* feat: MergeSort LogicalPlan

* update sqlness result

* Apply suggestions from code review

Co-authored-by: Lei, HUANG <[email protected]>

* refactor: per review advice

* refactor: more per review

* chore: per review

---------

Co-authored-by: Ruihang Xia <[email protected]>
Co-authored-by: Lei, HUANG <[email protected]>
  • Loading branch information
3 people authored Oct 9, 2024
1 parent 0dd11f5 commit a8ed3db
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 10 deletions.
3 changes: 2 additions & 1 deletion src/query/src/dist_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
mod analyzer;
mod commutativity;
mod merge_scan;
mod merge_sort;
mod planner;

pub use analyzer::DistPlannerAnalyzer;
pub use merge_scan::{MergeScanExec, MergeScanLogicalPlan};
pub use planner::DistExtensionPlanner;
pub use planner::{DistExtensionPlanner, MergeSortExtensionPlanner};
4 changes: 2 additions & 2 deletions src/query/src/dist_plan/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ impl PlanRewriter {
{
return true;
}

match Categorizer::check_plan(plan, self.partition_cols.clone()) {
Commutativity::Commutative => {}
Commutativity::PartialCommutative => {
Expand Down Expand Up @@ -265,9 +264,10 @@ impl PlanRewriter {

// add merge scan as the new root
let mut node = MergeScanLogicalPlan::new(on_node, false).into_logical_plan();

// expand stages
for new_stage in self.stage.drain(..) {
node = new_stage.with_new_exprs(new_stage.expressions(), vec![node.clone()])?
node = new_stage.with_new_exprs(new_stage.expressions(), vec![node.clone()])?;
}
self.set_expanded();

Expand Down
9 changes: 6 additions & 3 deletions src/query/src/dist_plan/commutativity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use promql::extension_plan::{
EmptyMetric, InstantManipulate, RangeManipulate, SeriesDivide, SeriesNormalize,
};

use crate::dist_plan::merge_sort::{merge_sort_transformer, MergeSortLogicalPlan};
use crate::dist_plan::MergeScanLogicalPlan;

#[allow(dead_code)]
Expand Down Expand Up @@ -68,8 +69,9 @@ impl Categorizer {
}

// sort plan needs to consider column priority
// We can implement a merge-sort on partial ordered data
Commutativity::PartialCommutative
// Change Sort to MergeSort which assumes the input streams are already sorted hence can be more efficient
// We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
Commutativity::ConditionalCommutative(Some(Arc::new(merge_sort_transformer)))
}
LogicalPlan::Join(_) => Commutativity::NonCommutative,
LogicalPlan::CrossJoin(_) => Commutativity::NonCommutative,
Expand Down Expand Up @@ -118,7 +120,8 @@ impl Categorizer {
|| name == SeriesNormalize::name()
|| name == RangeManipulate::name()
|| name == SeriesDivide::name()
|| name == MergeScanLogicalPlan::name() =>
|| name == MergeScanLogicalPlan::name()
|| name == MergeSortLogicalPlan::name() =>
{
Commutativity::Unimplemented
}
Expand Down
8 changes: 8 additions & 0 deletions src/query/src/dist_plan/merge_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,14 @@ impl MergeScanExec {
pub fn sub_stage_metrics(&self) -> Vec<RecordBatchMetrics> {
self.sub_stage_metrics.lock().unwrap().clone()
}

pub fn partition_count(&self) -> usize {
self.target_partition
}

pub fn region_count(&self) -> usize {
self.regions.len()
}
}

impl ExecutionPlan for MergeScanExec {
Expand Down
124 changes: 124 additions & 0 deletions src/query/src/dist_plan/merge_sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Merge sort logical plan for distributed query execution, roughly corresponding to the
//! `SortPreservingMergeExec` operator in datafusion
//!

use std::fmt;
use std::sync::Arc;

use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Expr, Extension, LogicalPlan, UserDefinedLogicalNodeCore};

/// MergeSort Logical Plan, have same field as `Sort`, but indicate it is a merge sort,
/// which assume each input partition is a sorted stream, and will use `SortPreserveingMergeExec`
/// to merge them into a single sorted stream.
#[derive(Hash, PartialEq, Eq, Clone)]
pub struct MergeSortLogicalPlan {
pub expr: Vec<Expr>,
pub input: Arc<LogicalPlan>,
pub fetch: Option<usize>,
}

impl fmt::Debug for MergeSortLogicalPlan {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
UserDefinedLogicalNodeCore::fmt_for_explain(self, f)
}
}

impl MergeSortLogicalPlan {
pub fn new(input: Arc<LogicalPlan>, expr: Vec<Expr>, fetch: Option<usize>) -> Self {
Self { input, expr, fetch }
}

pub fn name() -> &'static str {
"MergeSort"
}

/// Create a [`LogicalPlan::Extension`] node from this merge sort plan
pub fn into_logical_plan(self) -> LogicalPlan {
LogicalPlan::Extension(Extension {
node: Arc::new(self),
})
}

/// Convert self to a [`Sort`] logical plan with same input and expressions
pub fn into_sort(self) -> LogicalPlan {
LogicalPlan::Sort(datafusion::logical_expr::Sort {
input: self.input.clone(),
expr: self.expr,
fetch: self.fetch,
})
}
}

impl UserDefinedLogicalNodeCore for MergeSortLogicalPlan {
fn name(&self) -> &str {
Self::name()
}

// Allow optimization here
fn inputs(&self) -> Vec<&LogicalPlan> {
vec![self.input.as_ref()]
}

fn schema(&self) -> &datafusion_common::DFSchemaRef {
self.input.schema()
}

// Allow further optimization
fn expressions(&self) -> Vec<datafusion_expr::Expr> {
self.expr.clone()
}

fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "MergeSort: ")?;
for (i, expr_item) in self.expr.iter().enumerate() {
if i > 0 {
write!(f, ", ")?;
}
write!(f, "{expr_item}")?;
}
if let Some(a) = self.fetch {
write!(f, ", fetch={a}")?;
}
Ok(())
}

fn with_exprs_and_inputs(
&self,
exprs: Vec<datafusion::prelude::Expr>,
mut inputs: Vec<LogicalPlan>,
) -> Result<Self> {
let mut zelf = self.clone();
zelf.expr = exprs;
zelf.input = Arc::new(inputs.pop().ok_or_else(|| {
DataFusionError::Internal("Expected exactly one input with MergeSort".to_string())
})?);
Ok(zelf)
}
}

/// Turn `Sort` into `MergeSort` if possible
pub fn merge_sort_transformer(plan: &LogicalPlan) -> Option<LogicalPlan> {
if let LogicalPlan::Sort(sort) = plan {
Some(
MergeSortLogicalPlan::new(sort.input.clone(), sort.expr.clone(), sort.fetch)
.into_logical_plan(),
)
} else {
None
}
}
62 changes: 61 additions & 1 deletion src/query/src/dist_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion::execution::context::SessionState;
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
use datafusion_common::TableReference;
use datafusion_common::{DataFusionError, TableReference};
use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
use session::context::QueryContext;
use snafu::{OptionExt, ResultExt};
Expand All @@ -35,9 +35,69 @@ use table::table::adapter::DfTableProviderAdapter;
use table::table_name::TableName;

use crate::dist_plan::merge_scan::{MergeScanExec, MergeScanLogicalPlan};
use crate::dist_plan::merge_sort::MergeSortLogicalPlan;
use crate::error::{CatalogSnafu, TableNotFoundSnafu};
use crate::region_query::RegionQueryHandlerRef;

/// Planner for convert merge sort logical plan to physical plan
/// it is currently a fallback to sort, and doesn't change the execution plan:
/// `MergeSort(MergeScan) -> Sort(MergeScan) - to physical plan -> ...`
/// It should be applied after `DistExtensionPlanner`
///
/// (Later when actually impl this merge sort)
///
/// We should ensure the number of partition is not smaller than the number of region at present. Otherwise this would result in incorrect output.
pub struct MergeSortExtensionPlanner {}

#[async_trait]
impl ExtensionPlanner for MergeSortExtensionPlanner {
async fn plan_extension(
&self,
planner: &dyn PhysicalPlanner,
node: &dyn UserDefinedLogicalNode,
_logical_inputs: &[&LogicalPlan],
physical_inputs: &[Arc<dyn ExecutionPlan>],
session_state: &SessionState,
) -> Result<Option<Arc<dyn ExecutionPlan>>> {
if let Some(merge_sort) = node.as_any().downcast_ref::<MergeSortLogicalPlan>() {
if let LogicalPlan::Extension(ext) = &merge_sort.input.as_ref()
&& ext
.node
.as_any()
.downcast_ref::<MergeScanLogicalPlan>()
.is_some()
{
let merge_scan_exec = physical_inputs
.first()
.and_then(|p| p.as_any().downcast_ref::<MergeScanExec>())
.ok_or(DataFusionError::Internal(format!(
"Expect MergeSort's input is a MergeScanExec, found {:?}",
physical_inputs
)))?;

let partition_cnt = merge_scan_exec.partition_count();
let region_cnt = merge_scan_exec.region_count();
// if partition >= region, we know that every partition stream of merge scan is ordered
// and we only need to do a merge sort, otherwise fallback to quick sort
let can_merge_sort = partition_cnt >= region_cnt;
if can_merge_sort {
// TODO(discord9): use `SortPreversingMergeExec here`
}
// for now merge sort only exist in logical plan, and have the same effect as `Sort`
// doesn't change the execution plan, this will change in the future
let ret = planner
.create_physical_plan(&merge_sort.clone().into_sort(), session_state)
.await?;
Ok(Some(ret))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
}

pub struct DistExtensionPlanner {
catalog_manager: CatalogManagerRef,
region_query_handler: RegionQueryHandlerRef,
Expand Down
3 changes: 2 additions & 1 deletion src/query/src/query_engine/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use promql::extension_plan::PromExtensionPlanner;
use table::table::adapter::DfTableProviderAdapter;
use table::TableRef;

use crate::dist_plan::{DistExtensionPlanner, DistPlannerAnalyzer};
use crate::dist_plan::{DistExtensionPlanner, DistPlannerAnalyzer, MergeSortExtensionPlanner};
use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule;
use crate::optimizer::parallelize_scan::ParallelizeScan;
use crate::optimizer::remove_duplicate::RemoveDuplicate;
Expand Down Expand Up @@ -295,6 +295,7 @@ impl DfQueryPlanner {
catalog_manager,
region_query_handler,
)));
planners.push(Arc::new(MergeSortExtensionPlanner {}));
}
Self {
physical_planner: DefaultPhysicalPlanner::with_extension_planners(planners),
Expand Down
2 changes: 1 addition & 1 deletion tests/cases/distributed/explain/multi_partitions.result
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ explain SELECT * FROM demo WHERE ts > cast(1000000000 as timestamp) ORDER BY hos
+-+-+
| plan_type_| plan_|
+-+-+
| logical_plan_| Sort: demo.host ASC NULLS LAST_|
| logical_plan_| MergeSort: demo.host ASC NULLS LAST_|
|_|_MergeScan [is_placeholder=false]_|
| physical_plan | SortPreservingMergeExec: [host@0 ASC NULLS LAST]_|
|_|_SortExec: expr=[host@0 ASC NULLS LAST], preserve_partitioning=[true]_|
Expand Down
2 changes: 1 addition & 1 deletion tests/cases/standalone/common/order/order_by.result
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ explain analyze select tag from t where num > 6 order by ts desc limit 2;
+-+-+-+
| 0_| 0_|_GlobalLimitExec: skip=0, fetch=2 REDACTED
|_|_|_SortPreservingMergeExec: [ts@1 DESC] REDACTED
|_|_|_SortExec: TopK(fetch=2), expr=[ts@1 DESC], preserve_partitioning=[true] REDACTED
|_|_|_SortExec: expr=[ts@1 DESC], preserve_partitioning=[true] REDACTED
|_|_|_MergeScanExec: REDACTED
|_|_|_|
| 1_| 0_|_GlobalLimitExec: skip=0, fetch=2 REDACTED
Expand Down

0 comments on commit a8ed3db

Please sign in to comment.