From e84b1eefdf509f17c8757bb37a2b9a69b6778cd7 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 30 Apr 2024 18:10:11 +0800 Subject: [PATCH] perf: optimize `count(*)` (#3845) * perf: optimize count(*) Signed-off-by: Ruihang Xia * fallback to count(1) for temporary table Signed-off-by: Ruihang Xia * handle alias expr in range plan Signed-off-by: Ruihang Xia * handle subquery alias Signed-off-by: Ruihang Xia * rename file Signed-off-by: Ruihang Xia --------- Signed-off-by: Ruihang Xia --- src/query/src/optimizer.rs | 3 +- src/query/src/optimizer/count_wildcard.rs | 156 ++++++++++++++++++ src/query/src/query_engine/state.rs | 12 +- src/query/src/range_select/plan.rs | 16 +- .../common/range/special_aggr.result | 25 ++- .../standalone/common/range/special_aggr.sql | 10 +- .../common/tql-explain-analyze/explain.result | 2 +- 7 files changed, 208 insertions(+), 16 deletions(-) create mode 100644 src/query/src/optimizer/count_wildcard.rs diff --git a/src/query/src/optimizer.rs b/src/query/src/optimizer.rs index 26b906fd985d..e6a971417c23 100644 --- a/src/query/src/optimizer.rs +++ b/src/query/src/optimizer.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod count_wildcard; pub mod order_hint; pub mod remove_duplicate; pub mod string_normalization; @@ -27,7 +28,7 @@ use crate::QueryEngineContext; /// [`ExtensionAnalyzerRule`]s transform [`LogicalPlan`]s in some way to make /// the plan valid prior to the rest of the DataFusion optimization process. -/// It's an extension of datafusion [`AnalyzerRule`]s but accepts [`QueryEngineContext` as the second parameter. +/// It's an extension of datafusion [`AnalyzerRule`]s but accepts [`QueryEngineContext`] as the second parameter. pub trait ExtensionAnalyzerRule { /// Rewrite `plan` fn analyze( diff --git a/src/query/src/optimizer/count_wildcard.rs b/src/query/src/optimizer/count_wildcard.rs new file mode 100644 index 000000000000..359d333c25c1 --- /dev/null +++ b/src/query/src/optimizer/count_wildcard.rs @@ -0,0 +1,156 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use datafusion::datasource::DefaultTableSource; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; +use datafusion_common::Result as DataFusionResult; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, WindowFunction}; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_expr::{col, lit, Expr, LogicalPlan, WindowFunctionDefinition}; +use datafusion_optimizer::utils::NamePreserver; +use datafusion_optimizer::AnalyzerRule; +use table::table::adapter::DfTableProviderAdapter; + +/// A replacement to DataFusion's [`CountWildcardRule`]. This rule +/// would prefer to use TIME INDEX for counting wildcard as it's +/// faster to read comparing to PRIMARY KEYs. +/// +/// [`CountWildcardRule`]: datafusion::optimizer::analyzer::CountWildcardRule +pub struct CountWildcardToTimeIndexRule; + +impl AnalyzerRule for CountWildcardToTimeIndexRule { + fn name(&self) -> &str { + "count_wildcard_to_time_index_rule" + } + + fn analyze( + &self, + plan: LogicalPlan, + _config: &datafusion::config::ConfigOptions, + ) -> DataFusionResult { + plan.transform_down_with_subqueries(&Self::analyze_internal) + .data() + } +} + +impl CountWildcardToTimeIndexRule { + fn analyze_internal(plan: LogicalPlan) -> DataFusionResult> { + let name_preserver = NamePreserver::new(&plan); + let new_arg = if let Some(time_index) = Self::try_find_time_index_col(&plan) { + vec![col(time_index)] + } else { + vec![lit(COUNT_STAR_EXPANSION)] + }; + plan.map_expressions(|expr| { + let original_name = name_preserver.save(&expr)?; + let transformed_expr = expr.transform_up_mut(&mut |expr| match expr { + Expr::WindowFunction(mut window_function) + if Self::is_count_star_window_aggregate(&window_function) => + { + window_function.args.clone_from(&new_arg); + Ok(Transformed::yes(Expr::WindowFunction(window_function))) + } + Expr::AggregateFunction(mut aggregate_function) + if Self::is_count_star_aggregate(&aggregate_function) => + { + aggregate_function.args.clone_from(&new_arg); + Ok(Transformed::yes(Expr::AggregateFunction( + aggregate_function, + ))) + } + _ => Ok(Transformed::no(expr)), + })?; + transformed_expr.map_data(|data| original_name.restore(data)) + }) + } + + fn try_find_time_index_col(plan: &LogicalPlan) -> Option { + let mut finder = TimeIndexFinder::default(); + // Safety: `TimeIndexFinder` won't throw error. + plan.visit(&mut finder).unwrap(); + finder.time_index + } +} + +/// Utility functions from the original rule. +impl CountWildcardToTimeIndexRule { + fn is_wildcard(expr: &Expr) -> bool { + matches!(expr, Expr::Wildcard { qualifier: None }) + } + + fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { + matches!( + &aggregate_function.func_def, + AggregateFunctionDefinition::BuiltIn( + datafusion_expr::aggregate_function::AggregateFunction::Count, + ) + ) && aggregate_function.args.len() == 1 + && Self::is_wildcard(&aggregate_function.args[0]) + } + + fn is_count_star_window_aggregate(window_function: &WindowFunction) -> bool { + matches!( + &window_function.fun, + WindowFunctionDefinition::AggregateFunction( + datafusion_expr::aggregate_function::AggregateFunction::Count, + ) + ) && window_function.args.len() == 1 + && Self::is_wildcard(&window_function.args[0]) + } +} + +#[derive(Default)] +struct TimeIndexFinder { + time_index: Option, + table_alias: Option, +} + +impl TreeNodeVisitor for TimeIndexFinder { + type Node = LogicalPlan; + + fn f_down(&mut self, node: &Self::Node) -> DataFusionResult { + if let LogicalPlan::SubqueryAlias(subquery_alias) = node { + self.table_alias = Some(subquery_alias.alias.to_string()); + } + + if let LogicalPlan::TableScan(table_scan) = &node { + if let Some(source) = table_scan + .source + .as_any() + .downcast_ref::() + { + if let Some(adapter) = source + .table_provider + .as_any() + .downcast_ref::() + { + let table_info = adapter.table().table_info(); + let col_name = table_info.meta.schema.timestamp_column().map(|c| &c.name); + let table_name = self.table_alias.as_ref().unwrap_or(&table_info.name); + self.time_index = col_name.map(|s| format!("{}.{}", table_name, s)); + + return Ok(TreeNodeRecursion::Stop); + } + } + } + + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, _node: &Self::Node) -> DataFusionResult { + Ok(TreeNodeRecursion::Stop) + } +} diff --git a/src/query/src/query_engine/state.rs b/src/query/src/query_engine/state.rs index bf2c01de33ed..3fa44a280617 100644 --- a/src/query/src/query_engine/state.rs +++ b/src/query/src/query_engine/state.rs @@ -43,6 +43,7 @@ use table::table::adapter::DfTableProviderAdapter; use table::TableRef; use crate::dist_plan::{DistExtensionPlanner, DistPlannerAnalyzer}; +use crate::optimizer::count_wildcard::CountWildcardToTimeIndexRule; use crate::optimizer::order_hint::OrderHintRule; use crate::optimizer::remove_duplicate::RemoveDuplicate; use crate::optimizer::string_normalization::StringNormalizationRule; @@ -89,18 +90,27 @@ impl QueryEngineState { let session_config = SessionConfig::new().with_create_default_catalog_and_schema(false); // Apply extension rules let mut extension_rules = Vec::new(); + // The [`TypeConversionRule`] must be at first extension_rules.insert(0, Arc::new(TypeConversionRule) as _); + // Apply the datafusion rules let mut analyzer = Analyzer::new(); analyzer.rules.insert(0, Arc::new(StringNormalizationRule)); + + // Use our custom rule instead to optimize the count(*) query Self::remove_analyzer_rule(&mut analyzer.rules, CountWildcardRule {}.name()); - analyzer.rules.insert(0, Arc::new(CountWildcardRule {})); + analyzer + .rules + .insert(0, Arc::new(CountWildcardToTimeIndexRule)); + if with_dist_planner { analyzer.rules.push(Arc::new(DistPlannerAnalyzer)); } + let mut optimizer = Optimizer::new(); optimizer.rules.push(Arc::new(OrderHintRule)); + // add physical optimizer let mut physical_optimizer = PhysicalOptimizer::new(); physical_optimizer.rules.push(Arc::new(RemoveDuplicate)); diff --git a/src/query/src/range_select/plan.rs b/src/query/src/range_select/plan.rs index d31097efc0e9..73be735d4de3 100644 --- a/src/query/src/range_select/plan.rs +++ b/src/query/src/range_select/plan.rs @@ -667,7 +667,13 @@ impl RangeSelect { .range_expr .iter() .map(|range_fn| { - let expr = match &range_fn.expr { + let name = range_fn.expr.display_name()?; + let range_expr = match &range_fn.expr { + Expr::Alias(expr) => expr.expr.as_ref(), + others => others, + }; + + let expr = match &range_expr { Expr::AggregateFunction( aggr @ datafusion_expr::expr::AggregateFunction { func_def: @@ -778,7 +784,7 @@ impl RangeSelect { &input_phy_exprs, &order_by, &input_schema, - range_fn.expr.display_name()?, + name, false, ), AggregateFunctionDefinition::UDF(fun) => create_aggr_udf_expr( @@ -787,7 +793,7 @@ impl RangeSelect { &[], &[], &input_schema, - range_fn.expr.display_name()?, + name, false, ), f => Err(DataFusionError::NotImplemented(format!( @@ -796,8 +802,8 @@ impl RangeSelect { } } _ => Err(DataFusionError::Plan(format!( - "Unexpected Expr:{} in RangeSelect", - range_fn.expr.display_name()? + "Unexpected Expr: {} in RangeSelect", + range_fn.expr.canonical_name() ))), }?; let args = expr.expressions(); diff --git a/tests/cases/standalone/common/range/special_aggr.result b/tests/cases/standalone/common/range/special_aggr.result index 6fab9998d4e1..2240d744389d 100644 --- a/tests/cases/standalone/common/range/special_aggr.result +++ b/tests/cases/standalone/common/range/special_aggr.result @@ -143,7 +143,7 @@ SELECT ts, host, first_value(addon ORDER BY val ASC, ts ASC) RANGE '5s', last_va | 1970-01-01T00:00:20 | host2 | 28 | 30 | +---------------------+-------+---------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------+ -SELECT ts, host, count(val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; +SELECT ts, host, count(val) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts; +---------------------+-------+--------------------------+ | ts | host | COUNT(host.val) RANGE 5s | @@ -160,7 +160,7 @@ SELECT ts, host, count(val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; | 1970-01-01T00:00:20 | host2 | 2 | +---------------------+-------+--------------------------+ -SELECT ts, host, count(distinct val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; +SELECT ts, host, count(distinct val) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts; +---------------------+-------+-----------------------------------+ | ts | host | COUNT(DISTINCT host.val) RANGE 5s | @@ -177,7 +177,7 @@ SELECT ts, host, count(distinct val) RANGE '5s'FROM host ALIGN '5s' ORDER BY hos | 1970-01-01T00:00:20 | host2 | 2 | +---------------------+-------+-----------------------------------+ -SELECT ts, host, count(*) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; +SELECT ts, host, count(*) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts; +---------------------+-------+-------------------+ | ts | host | COUNT(*) RANGE 5s | @@ -194,7 +194,24 @@ SELECT ts, host, count(*) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; | 1970-01-01T00:00:20 | host2 | 3 | +---------------------+-------+-------------------+ -SELECT ts, host, count(distinct *) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; +SELECT ts, host, count(1) RANGE '5s' as abc FROM host ALIGN '5s' ORDER BY host, ts; + ++---------------------+-------+-----+ +| ts | host | abc | ++---------------------+-------+-----+ +| 1970-01-01T00:00:00 | host1 | 3 | +| 1970-01-01T00:00:05 | host1 | 3 | +| 1970-01-01T00:00:10 | host1 | 3 | +| 1970-01-01T00:00:15 | host1 | 3 | +| 1970-01-01T00:00:20 | host1 | 3 | +| 1970-01-01T00:00:00 | host2 | 3 | +| 1970-01-01T00:00:05 | host2 | 3 | +| 1970-01-01T00:00:10 | host2 | 3 | +| 1970-01-01T00:00:15 | host2 | 3 | +| 1970-01-01T00:00:20 | host2 | 3 | ++---------------------+-------+-----+ + +SELECT ts, host, count(distinct *) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts; +---------------------+-------+----------------------------+ | ts | host | COUNT(DISTINCT *) RANGE 5s | diff --git a/tests/cases/standalone/common/range/special_aggr.sql b/tests/cases/standalone/common/range/special_aggr.sql index bf3cd9e29c6d..34a6b691443c 100644 --- a/tests/cases/standalone/common/range/special_aggr.sql +++ b/tests/cases/standalone/common/range/special_aggr.sql @@ -58,13 +58,15 @@ SELECT ts, host, first_value(addon ORDER BY val ASC NULLS FIRST) RANGE '5s', las SELECT ts, host, first_value(addon ORDER BY val ASC, ts ASC) RANGE '5s', last_value(addon ORDER BY val ASC, ts ASC) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts; -SELECT ts, host, count(val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; +SELECT ts, host, count(val) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts; -SELECT ts, host, count(distinct val) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; +SELECT ts, host, count(distinct val) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts; -SELECT ts, host, count(*) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; +SELECT ts, host, count(*) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts; -SELECT ts, host, count(distinct *) RANGE '5s'FROM host ALIGN '5s' ORDER BY host, ts; +SELECT ts, host, count(1) RANGE '5s' as abc FROM host ALIGN '5s' ORDER BY host, ts; + +SELECT ts, host, count(distinct *) RANGE '5s' FROM host ALIGN '5s' ORDER BY host, ts; -- Test error first_value/last_value diff --git a/tests/cases/standalone/common/tql-explain-analyze/explain.result b/tests/cases/standalone/common/tql-explain-analyze/explain.result index 024c3f80de60..a49624011d6c 100644 --- a/tests/cases/standalone/common/tql-explain-analyze/explain.result +++ b/tests/cases/standalone/common/tql-explain-analyze/explain.result @@ -86,7 +86,7 @@ TQL EXPLAIN VERBOSE (0, 10, '5s') test; |_|_Filter: test.j >= TimestampMillisecond(-300000, None) AND test.j <= TimestampMillisecond(300000, None)_| |_|_TableScan: test_| | logical_plan after apply_function_rewrites_| SAME TEXT AS ABOVE_| -| logical_plan after count_wildcard_rule_| SAME TEXT AS ABOVE_| +| logical_plan after count_wildcard_to_time_index_rule_| SAME TEXT AS ABOVE_| | logical_plan after StringNormalizationRule_| SAME TEXT AS ABOVE_| | logical_plan after inline_table_scan_| SAME TEXT AS ABOVE_| | logical_plan after type_coercion_| SAME TEXT AS ABOVE_|