diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 541448ebf1491..dcebbb55fb66d 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -59,7 +59,7 @@ pub fn main() -> Result<()> { // then run the optimizer with our custom rule let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); - let optimized_plan = optimizer.optimize(&analyzed_plan, &config, observe)?; + let optimized_plan = optimizer.optimize(analyzed_plan, &config, observe)?; println!( "Optimized Logical Plan:\n\n{}\n", optimized_plan.display_indent() diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index da3c108aa47cd..be209ca02077f 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -100,6 +100,9 @@ pub trait TreeNode: Sized { /// Visit the tree node using the given [`TreeNodeVisitor`], performing a /// depth-first walk of the node and its children. /// + /// See also: + /// * [`Self::rewrite`] to rewrite owned `TreeNode`s + /// /// Consider the following tree structure: /// ```text /// ParentNode @@ -144,6 +147,9 @@ pub trait TreeNode: Sized { /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for /// recursively transforming [`TreeNode`]s. /// + /// See also: + /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s + /// /// Consider the following tree structure: /// ```text /// ParentNode @@ -353,13 +359,15 @@ pub trait TreeNode: Sized { } /// Apply the closure `F` to the node's children. + /// + /// See `mutate_children` for rewriting in place fn apply_children Result>( &self, f: &mut F, ) -> Result; - /// Apply transform `F` to the node's children. Note that the transform `F` - /// might have a direction (pre-order or post-order). + /// Apply transform `F` to potentially rewrite the node's children. Note + /// that the transform `F` might have a direction (pre-order or post-order). fn map_children Result>>( self, f: F, @@ -489,6 +497,11 @@ impl Transformed { f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) } + /// Invokes f(), depending on the value of self.tnr. + /// + /// This is used to conditionally apply a function during a f_up tree + /// traversal, if the result of children traversal was `[`TreeNodeRecursion::Continue`]. + /// /// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`] /// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently /// when we are traversing down or up on a tree. If [`TreeNodeRecursion`] of diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 4eaaf94ecf5df..e417f5990bb08 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1892,7 +1892,7 @@ impl SessionState { // optimize the child plan, capturing the output of each optimizer let optimized_plan = self.optimizer.optimize( - &analyzed_plan, + analyzed_plan, self, |optimized_plan, optimizer| { let optimizer_name = optimizer.name().to_string(); @@ -1922,7 +1922,7 @@ impl SessionState { let analyzed_plan = self.analyzer .execute_and_check(plan, self.options(), |_, _| {})?; - self.optimizer.optimize(&analyzed_plan, self, |_, _| {}) + self.optimizer.optimize(analyzed_plan, self, |_, _| {}) } } diff --git a/datafusion/core/tests/optimizer_integration.rs b/datafusion/core/tests/optimizer_integration.rs index 60010bdddfb82..6e938361ddb48 100644 --- a/datafusion/core/tests/optimizer_integration.rs +++ b/datafusion/core/tests/optimizer_integration.rs @@ -110,7 +110,7 @@ fn test_sql(sql: &str) -> Result { let optimizer = Optimizer::new(); // analyze and optimize the logical plan let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - optimizer.optimize(&plan, &config, |_, _| {}) + optimizer.optimize(plan, &config, |_, _| {}) } #[derive(Default)] diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index b94cf37c5c12b..744feeac09146 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -337,7 +337,7 @@ mod tests { Operator, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), plan, @@ -377,7 +377,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ \n Projection: sq_2.c [c:UInt32]\ \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional AND filter @@ -403,7 +403,7 @@ mod tests { \n Projection: sq.c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for IN subquery with additional OR filter @@ -429,7 +429,7 @@ mod tests { \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -457,7 +457,7 @@ mod tests { \n Projection: sq2.c [c:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for nested IN subqueries @@ -486,7 +486,7 @@ mod tests { \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for filter input modification in case filter not supported @@ -518,7 +518,7 @@ mod tests { \n Projection: sq_inner.c [c:UInt32]\ \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test multiple correlated subqueries @@ -556,7 +556,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -606,7 +606,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -641,7 +641,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -674,7 +674,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -705,7 +705,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -738,7 +738,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -771,7 +771,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -805,7 +805,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); @@ -862,7 +862,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -895,7 +895,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -961,7 +961,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -999,7 +999,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1029,7 +1029,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1053,7 +1053,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1077,7 +1077,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1106,7 +1106,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1141,7 +1141,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1177,7 +1177,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1223,7 +1223,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1254,7 +1254,7 @@ mod tests { assert_optimized_plan_eq_display_indent( Arc::new(DecorrelatePredicateSubquery::new()), - &plan, + plan, expected, ); Ok(()) @@ -1288,7 +1288,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test recursive correlated subqueries @@ -1331,7 +1331,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional subquery filters @@ -1361,7 +1361,7 @@ mod tests { \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1386,7 +1386,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for exists subquery with both columns in schema @@ -1404,7 +1404,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for correlated exists subquery not equal @@ -1432,7 +1432,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery less than @@ -1460,7 +1460,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with subquery disjunction @@ -1489,7 +1489,7 @@ mod tests { \n Projection: orders.o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists without projection @@ -1515,7 +1515,7 @@ mod tests { \n SubqueryAlias: __correlated_sq_1 [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists expressions @@ -1543,7 +1543,7 @@ mod tests { \n Projection: orders.o_custkey + Int32(1), orders.o_custkey [orders.o_custkey + Int32(1):Int64, o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with additional filters @@ -1571,7 +1571,7 @@ mod tests { \n Projection: orders.o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated exists subquery filter with disjustions @@ -1598,7 +1598,7 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for correlated EXISTS subquery filter @@ -1623,7 +1623,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } /// Test for single exists subquery filter @@ -1635,7 +1635,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } /// Test for single NOT exists subquery filter @@ -1647,7 +1647,7 @@ mod tests { .project(vec![col("test.b")])? .build()?; - assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), &plan) + assert_optimization_skipped(Arc::new(DecorrelatePredicateSubquery::new()), plan) } #[test] @@ -1686,7 +1686,7 @@ mod tests { \n Projection: sq2.c, sq2.a [c:UInt32, a:UInt32]\ \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1712,7 +1712,7 @@ mod tests { \n Projection: UInt32(1), sq.a [UInt32(1):UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1738,7 +1738,7 @@ mod tests { \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1766,7 +1766,7 @@ mod tests { \n Projection: sq.c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1794,7 +1794,7 @@ mod tests { \n Projection: sq.b + sq.c, sq.a [sq.b + sq.c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1822,6 +1822,6 @@ mod tests { \n Projection: UInt32(1), sq.c, sq.a [UInt32(1):UInt32, c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_duplicated_expr.rs b/datafusion/optimizer/src/eliminate_duplicated_expr.rs index de05717a72e27..fae0eb5c8b1d9 100644 --- a/datafusion/optimizer/src/eliminate_duplicated_expr.rs +++ b/datafusion/optimizer/src/eliminate_duplicated_expr.rs @@ -114,7 +114,7 @@ mod tests { use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder}; use std::sync::Arc; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( Arc::new(EliminateDuplicatedExpr::new()), plan, @@ -132,7 +132,7 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a, test.b, test.c\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -151,6 +151,6 @@ mod tests { let expected = "Limit: skip=5, fetch=10\ \n Sort: test.a ASC NULLS FIRST, test.b ASC NULLS LAST\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index fea14342ca774..9287752a3f992 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -88,7 +88,7 @@ mod tests { use crate::test::*; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) } @@ -104,7 +104,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -119,7 +119,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -141,7 +141,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -156,7 +156,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -179,7 +179,7 @@ mod tests { \n TableScan: test\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -202,6 +202,6 @@ mod tests { // Filter is removed let expected = "Projection: test.a\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_join.rs b/datafusion/optimizer/src/eliminate_join.rs index 0dbebcc8a0519..f4123c6503e8b 100644 --- a/datafusion/optimizer/src/eliminate_join.rs +++ b/datafusion/optimizer/src/eliminate_join.rs @@ -82,7 +82,7 @@ mod tests { use datafusion_expr::{logical_plan::builder::LogicalPlanBuilder, Expr, LogicalPlan}; use std::sync::Arc; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateJoin::new()), plan, expected) } @@ -97,7 +97,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -114,6 +114,6 @@ mod tests { CrossJoin:\ \n EmptyRelation\ \n EmptyRelation"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 4386253740aaa..bfd660ce884fc 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -93,24 +93,19 @@ mod tests { use crate::push_down_limit::PushDownLimit; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { let optimizer = Optimizer::with_rules(vec![Arc::new(EliminateLimit::new())]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } fn assert_optimized_plan_eq_with_pushdown( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} @@ -124,7 +119,6 @@ mod tests { .expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } @@ -137,7 +131,7 @@ mod tests { .build()?; // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -157,7 +151,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -171,7 +165,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -191,7 +185,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq_with_pushdown(&plan, expected) + assert_optimized_plan_eq_with_pushdown(plan, expected) } #[test] @@ -209,7 +203,7 @@ mod tests { \n Limit: skip=0, fetch=2\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -227,7 +221,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -249,7 +243,7 @@ mod tests { \n Limit: skip=2, fetch=1\ \n TableScan: test\ \n TableScan: test1"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -262,6 +256,6 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_nested_union.rs b/datafusion/optimizer/src/eliminate_nested_union.rs index 5771ea2e19a29..d27766b33543e 100644 --- a/datafusion/optimizer/src/eliminate_nested_union.rs +++ b/datafusion/optimizer/src/eliminate_nested_union.rs @@ -114,7 +114,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateNestedUnion::new()), plan, expected) } @@ -131,7 +131,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -147,7 +147,7 @@ mod tests { \n Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -167,7 +167,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -188,7 +188,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -210,7 +210,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -230,7 +230,7 @@ mod tests { \n TableScan: table\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // We don't need to use project_with_column_index in logical optimizer, @@ -261,7 +261,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -291,7 +291,7 @@ mod tests { \n TableScan: table\ \n Projection: table.id AS id, table.key, table.value\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -337,7 +337,7 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -384,6 +384,6 @@ mod tests { \n TableScan: table_1\ \n Projection: CAST(table_1.id AS Int64) AS id, table_1.key, CAST(table_1.value AS Float64) AS value\ \n TableScan: table_1"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_one_union.rs b/datafusion/optimizer/src/eliminate_one_union.rs index 70ee490346ffb..cb79cd88bd035 100644 --- a/datafusion/optimizer/src/eliminate_one_union.rs +++ b/datafusion/optimizer/src/eliminate_one_union.rs @@ -76,7 +76,7 @@ mod tests { ]) } - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_with_rules( vec![Arc::new(EliminateOneUnion::new())], plan, @@ -97,7 +97,7 @@ mod tests { Union\ \n TableScan: table\ \n TableScan: table"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -113,6 +113,6 @@ mod tests { }); let expected = "TableScan: table"; - assert_optimized_plan_equal(&single_union_plan, expected) + assert_optimized_plan_equal(single_union_plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 56a4a76987f75..edc2131564b7b 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -306,7 +306,7 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) } @@ -330,7 +330,7 @@ mod tests { \n Left Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -353,7 +353,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -380,7 +380,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -407,7 +407,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -434,6 +434,6 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/extract_equijoin_predicate.rs b/datafusion/optimizer/src/extract_equijoin_predicate.rs index 24664d57c38d8..efe92e2702b38 100644 --- a/datafusion/optimizer/src/extract_equijoin_predicate.rs +++ b/datafusion/optimizer/src/extract_equijoin_predicate.rs @@ -164,7 +164,7 @@ mod tests { col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType, }; - fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(ExtractEquijoinPredicate {}), plan, @@ -186,7 +186,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -205,7 +205,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -228,7 +228,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -255,7 +255,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -281,7 +281,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -318,7 +318,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -351,7 +351,7 @@ mod tests { \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } #[test] @@ -375,6 +375,6 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_plan_eq(&plan, expected) + assert_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 95cd8a9fd36ca..a91768312fcfc 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -119,7 +119,7 @@ mod tests { use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, logical_plan::JoinType, LogicalPlanBuilder}; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) } @@ -131,7 +131,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -142,7 +142,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -179,7 +179,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -200,7 +200,7 @@ mod tests { \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -221,7 +221,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -244,7 +244,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } fn build_plan( diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index c40a9bb704ebf..d65056d1798a0 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -941,7 +941,7 @@ mod tests { UserDefinedLogicalNodeCore, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(OptimizeProjections::new()), plan, expected) } @@ -1090,7 +1090,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1104,7 +1104,7 @@ mod tests { let expected = "Projection: Int32(1) + test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1117,7 +1117,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1130,7 +1130,7 @@ mod tests { let expected = "Projection: test.a AS alias\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1152,7 +1152,7 @@ mod tests { \n Projection: \ \n Aggregate: groupBy=[[]], aggr=[[COUNT(Int32(1))]]\ \n TableScan: ?table? projection=[]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1175,7 +1175,7 @@ mod tests { .build()?; let expected = "Projection: (?table?.s)[x]\ \n TableScan: ?table? projection=[s]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1187,7 +1187,7 @@ mod tests { let expected = "Projection: (- test.a)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1199,7 +1199,7 @@ mod tests { let expected = "Projection: test.a IS NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1211,7 +1211,7 @@ mod tests { let expected = "Projection: test.a IS NOT NULL\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1223,7 +1223,7 @@ mod tests { let expected = "Projection: test.a IS TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1235,7 +1235,7 @@ mod tests { let expected = "Projection: test.a IS NOT TRUE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1247,7 +1247,7 @@ mod tests { let expected = "Projection: test.a IS FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1259,7 +1259,7 @@ mod tests { let expected = "Projection: test.a IS NOT FALSE\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1271,7 +1271,7 @@ mod tests { let expected = "Projection: test.a IS UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1283,7 +1283,7 @@ mod tests { let expected = "Projection: test.a IS NOT UNKNOWN\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1295,7 +1295,7 @@ mod tests { let expected = "Projection: NOT test.a\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1307,7 +1307,7 @@ mod tests { let expected = "Projection: TRY_CAST(test.a AS Float64)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1323,7 +1323,7 @@ mod tests { let expected = "Projection: test.a SIMILAR TO Utf8(\"[0-9]\")\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -1335,7 +1335,7 @@ mod tests { let expected = "Projection: test.a BETWEEN Int32(1) AND Int32(3)\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Test outer projection isn't discarded despite the same schema as inner @@ -1356,7 +1356,7 @@ mod tests { let expected = "Projection: test.a, CASE WHEN test.a = Int32(1) THEN Int32(10) ELSE d END AS d\ \n Projection: test.a, Int32(0) AS d\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Since only column `a` is referred at the output. Scan should only contain projection=[a]. @@ -1377,7 +1377,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Only column `a` is referred at the output. However, User defined node itself uses column `b` @@ -1404,7 +1404,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Only column `a` is referred at the output. However, User defined node itself uses expression `b+c` @@ -1439,7 +1439,7 @@ mod tests { let expected = "Projection: test.a, Int32(0) AS d\ \n NoOpUserDefined\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Columns `l.a`, `l.c`, `r.a` is referred at the output. @@ -1464,6 +1464,6 @@ mod tests { \n UserDefinedCrossJoin\ \n TableScan: l projection=[a, c]\ \n TableScan: r projection=[a]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 3153f72d7ee70..39f811976f8c4 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -48,10 +48,11 @@ use crate::utils::log_plan; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::instant::Instant; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::logical_plan::LogicalPlan; use chrono::{DateTime, Utc}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use log::{debug, warn}; /// `OptimizerRule` transforms one [`LogicalPlan`] into another which @@ -184,39 +185,12 @@ pub struct Optimizer { pub rules: Vec>, } -/// If a rule is with `ApplyOrder`, it means the optimizer will derive to handle children instead of -/// recursively handling in rule. -/// We just need handle a subtree pattern itself. +/// Specifies how recursion for an `OptimizerRule` should be handled. /// -/// Notice: **sometime** result after optimize still can be optimized, we need apply again. -/// -/// Usage Example: Merge Limit (subtree pattern is: Limit-Limit) -/// ```rust -/// use datafusion_expr::{Limit, LogicalPlan, LogicalPlanBuilder}; -/// use datafusion_common::Result; -/// fn merge_limit(parent: &Limit, child: &Limit) -> LogicalPlan { -/// // just for run -/// return parent.input.as_ref().clone(); -/// } -/// fn try_optimize(plan: &LogicalPlan) -> Result> { -/// match plan { -/// LogicalPlan::Limit(limit) => match limit.input.as_ref() { -/// LogicalPlan::Limit(child_limit) => { -/// // merge limit ... -/// let optimized_plan = merge_limit(limit, child_limit); -/// // due to optimized_plan may be optimized again, -/// // for example: plan is Limit-Limit-Limit -/// Ok(Some( -/// try_optimize(&optimized_plan)? -/// .unwrap_or_else(|| optimized_plan.clone()), -/// )) -/// } -/// _ => Ok(None), -/// }, -/// _ => Ok(None), -/// } -/// } -/// ``` +/// If an `OptimizerRule` returns `Some` from `apply_order`, it means the +/// optimizer will handle recursively applying the rule to the plan. If the +/// apply order is `None`, the rule must handle any required recursion itself. +#[derive(Debug, Clone, Copy, PartialEq)] pub enum ApplyOrder { TopDown, BottomUp, @@ -274,22 +248,82 @@ impl Optimizer { pub fn with_rules(rules: Vec>) -> Self { Self { rules } } +} + +/// Rewrites LogicalPlan nodes +struct Rewriter<'a> { + apply_order: ApplyOrder, + rule: &'a dyn OptimizerRule, + config: &'a dyn OptimizerConfig, +} + +impl<'a> Rewriter<'a> { + fn new( + apply_order: ApplyOrder, + rule: &'a dyn OptimizerRule, + config: &'a dyn OptimizerConfig, + ) -> Self { + Self { + apply_order, + rule, + config, + } + } +} + +impl<'a> TreeNodeRewriter for Rewriter<'a> { + type Node = LogicalPlan; + + fn f_down(&mut self, node: LogicalPlan) -> Result> { + if self.apply_order == ApplyOrder::TopDown { + optimize_plan_node(node, self.rule, self.config) + } else { + Ok(Transformed::no(node)) + } + } + fn f_up(&mut self, node: LogicalPlan) -> Result> { + if self.apply_order == ApplyOrder::BottomUp { + optimize_plan_node(node, self.rule, self.config) + } else { + Ok(Transformed::no(node)) + } + } +} + +/// Invokes the Optimizer rule to rewrite the LogicalPlan in place. +fn optimize_plan_node( + plan: LogicalPlan, + rule: &dyn OptimizerRule, + config: &dyn OptimizerConfig, +) -> Result> { + // TODO: introduce a better API to OptimizerRule to allow rewriting by ownership + rule.try_optimize(&plan, config).map(|maybe_plan| { + match maybe_plan { + Some(new_plan) => { + // if the node was rewritten by the optimizer, replace the node + Transformed::yes(new_plan) + } + None => Transformed::no(plan), + } + }) +} + +impl Optimizer { /// Optimizes the logical plan by applying optimizer rules, and /// invoking observer function after each call pub fn optimize( &self, - plan: &LogicalPlan, + plan: LogicalPlan, config: &dyn OptimizerConfig, mut observer: F, ) -> Result where F: FnMut(&LogicalPlan, &dyn OptimizerRule), { - let options = config.options(); - let mut new_plan = plan.clone(); - let start_time = Instant::now(); + let options = config.options(); + let mut new_plan = plan; let mut previous_plans = HashSet::with_capacity(16); previous_plans.insert(LogicalPlanSignature::new(&new_plan)); @@ -299,44 +333,77 @@ impl Optimizer { log_plan(&format!("Optimizer input (pass {i})"), &new_plan); for rule in &self.rules { - let result = - self.optimize_recursively(rule, &new_plan, config) - .and_then(|plan| { - if let Some(plan) = &plan { - assert_schema_is_the_same(rule.name(), plan, &new_plan)?; - } - Ok(plan) - }); - match result { - Ok(Some(plan)) => { - new_plan = plan; - observer(&new_plan, rule.as_ref()); - log_plan(rule.name(), &new_plan); + // If we need to skip failed rules, must copy plan before attempting to rewrite + // as rewriting is destructive + let prev_plan = options + .optimizer + .skip_failed_rules + .then(|| new_plan.clone()); + + let starting_schema = new_plan.schema().clone(); + + let result = match rule.apply_order() { + // optimizer handles recursion + Some(apply_order) => new_plan.rewrite(&mut Rewriter::new( + apply_order, + rule.as_ref(), + config, + )), + // rule handles recursion + None => optimize_plan_node(new_plan, rule.as_ref(), config), + } + // verify the rule didn't change the schema + .and_then(|tnr| { + if tnr.transformed { + assert_only_schema_is_the_same( + rule.name(), + &starting_schema, + &tnr.data, + )?; } - Ok(None) => { + Ok(tnr) + }); + + // Handle results + match (result, prev_plan) { + // OptimizerRule was successful + ( + Ok(Transformed { + data, transformed, .. + }), + _, + ) => { + new_plan = data; observer(&new_plan, rule.as_ref()); - debug!( - "Plan unchanged by optimizer rule '{}' (pass {})", - rule.name(), - i - ); + if transformed { + log_plan(rule.name(), &new_plan); + } else { + debug!( + "Plan unchanged by optimizer rule '{}' (pass {})", + rule.name(), + i + ); + } } - Err(e) => { - if options.optimizer.skip_failed_rules { - // Note to future readers: if you see this warning it signals a - // bug in the DataFusion optimizer. Please consider filing a ticket - // https://github.com/apache/arrow-datafusion - warn!( + // OptimizerRule was unsuccessful, but skipped failed rules is on + // so use the previous plan + (Err(e), Some(orig_plan)) => { + // Note to future readers: if you see this warning it signals a + // bug in the DataFusion optimizer. Please consider filing a ticket + // https://github.com/apache/arrow-datafusion + warn!( "Skipping optimizer rule '{}' due to unexpected error: {}", rule.name(), e ); - } else { - return Err(DataFusionError::Context( - format!("Optimizer rule '{}' failed", rule.name(),), - Box::new(e), - )); - } + new_plan = orig_plan; + } + // OptimizerRule was unsuccessful, but skipped failed rules is off, return error + (Err(e), None) => { + return Err(e.context(format!( + "Optimizer rule '{}' failed", + rule.name() + ))); } } } @@ -356,97 +423,22 @@ impl Optimizer { debug!("Optimizer took {} ms", start_time.elapsed().as_millis()); Ok(new_plan) } - - fn optimize_node( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - // TODO: future feature: We can do Batch optimize - rule.try_optimize(plan, config) - } - - fn optimize_inputs( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - let inputs = plan.inputs(); - let result = inputs - .iter() - .map(|sub_plan| self.optimize_recursively(rule, sub_plan, config)) - .collect::>>()?; - if result.is_empty() || result.iter().all(|o| o.is_none()) { - return Ok(None); - } - - let new_inputs = result - .into_iter() - .zip(inputs) - .map(|(new_plan, old_plan)| match new_plan { - Some(plan) => plan, - None => old_plan.clone(), - }) - .collect(); - - let exprs = plan.expressions(); - plan.with_new_exprs(exprs, new_inputs).map(Some) - } - - /// Use a rule to optimize the whole plan. - /// If the rule with `ApplyOrder`, we don't need to recursively handle children in rule. - pub fn optimize_recursively( - &self, - rule: &Arc, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - match rule.apply_order() { - Some(order) => match order { - ApplyOrder::TopDown => { - let optimize_self_opt = self.optimize_node(rule, plan, config)?; - let optimize_inputs_opt = match &optimize_self_opt { - Some(optimized_plan) => { - self.optimize_inputs(rule, optimized_plan, config)? - } - _ => self.optimize_inputs(rule, plan, config)?, - }; - Ok(optimize_inputs_opt.or(optimize_self_opt)) - } - ApplyOrder::BottomUp => { - let optimize_inputs_opt = self.optimize_inputs(rule, plan, config)?; - let optimize_self_opt = match &optimize_inputs_opt { - Some(optimized_plan) => { - self.optimize_node(rule, optimized_plan, config)? - } - _ => self.optimize_node(rule, plan, config)?, - }; - Ok(optimize_self_opt.or(optimize_inputs_opt)) - } - }, - _ => rule.try_optimize(plan, config), - } - } } -/// Returns an error if plans have different schemas. +/// Returns an error if the plan has a different schema than `prev_schema` /// /// It ignores metadata and nullability. -pub(crate) fn assert_schema_is_the_same( +pub(crate) fn assert_only_schema_is_the_same( rule_name: &str, - prev_plan: &LogicalPlan, + prev_schema: &DFSchema, new_plan: &LogicalPlan, ) -> Result<()> { - let equivalent = new_plan - .schema() - .equivalent_names_and_types(prev_plan.schema()); + let equivalent = new_plan.schema().equivalent_names_and_types(prev_schema); if !equivalent { let e = DataFusionError::Internal(format!( "Failed due to a difference in schemas, original schema: {:?}, new schema: {:?}", - prev_plan.schema(), + prev_schema, new_plan.schema() )); Err(DataFusionError::Context( @@ -479,7 +471,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -490,7 +482,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( "Optimizer rule 'bad rule' failed\ncaused by\n\ Error during planning: rule failed", @@ -506,21 +498,27 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - let err = opt.optimize(&plan, &config, &observe).unwrap_err(); + let err = opt.optimize(plan, &config, &observe).unwrap_err(); assert_eq!( - "Optimizer rule 'get table_scan rule' failed\ncaused by\nget table_scan rule\ncaused by\n\ - Internal error: Failed due to a difference in schemas, original schema: \ - DFSchema { inner: Schema { fields: \ - [Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ - Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ - Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }, \ - field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" })], \ - functional_dependencies: FunctionalDependencies { deps: [] } }, \ + "Optimizer rule 'get table_scan rule' failed\n\ + caused by\nget table_scan rule\ncaused by\n\ + Internal error: Failed due to a difference in schemas, \ + original schema: DFSchema { inner: Schema { \ + fields: [], \ + metadata: {} }, \ + field_qualifiers: [], \ + functional_dependencies: FunctionalDependencies { deps: [] } \ + }, \ new schema: DFSchema { inner: Schema { \ - fields: [], metadata: {} }, \ - field_qualifiers: [], \ - functional_dependencies: FunctionalDependencies { deps: [] } }.\n\ - This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker", + fields: [\ + Field { name: \"a\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ + Field { name: \"b\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, \ + Field { name: \"c\", data_type: UInt32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }\ + ], \ + metadata: {} }, \ + field_qualifiers: [Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" }), Some(Bare { table: \"test\" })], \ + functional_dependencies: FunctionalDependencies { deps: [] } }.\n\ + This was likely caused by a bug in DataFusion's code and we would welcome that you file an bug report in our issue tracker", err.strip_backtrace() ); } @@ -533,7 +531,7 @@ mod tests { produce_one_row: false, schema: Arc::new(DFSchema::empty()), }); - opt.optimize(&plan, &config, &observe).unwrap(); + opt.optimize(plan, &config, &observe).unwrap(); } #[test] @@ -554,7 +552,7 @@ mod tests { // optimizing should be ok, but the schema will have changed (no metadata) assert_ne!(plan.schema().as_ref(), input_schema.as_ref()); - let optimized_plan = opt.optimize(&plan, &config, &observe)?; + let optimized_plan = opt.optimize(plan, &config, &observe)?; // metadata was removed assert_eq!(optimized_plan.schema().as_ref(), input_schema.as_ref()); Ok(()) @@ -575,7 +573,7 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan.clone(), &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 3 plans assert_eq!(3, plans.len()); @@ -601,7 +599,7 @@ mod tests { let mut plans: Vec = Vec::new(); let final_plan = - opt.optimize(&initial_plan, &config, |p, _| plans.push(p.clone()))?; + opt.optimize(initial_plan, &config, |p, _| plans.push(p.clone()))?; // initial_plan is not observed, so we have 4 plans assert_eq!(4, plans.len()); diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 55fb982d2a875..1a7fccc7f6a0d 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -197,12 +197,12 @@ mod tests { use super::*; - fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_eq(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) } fn assert_together_optimized_plan_eq( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { assert_optimized_plan_eq_with_rules( @@ -225,7 +225,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_eq(&plan, expected) + assert_eq(plan, expected) } #[test] @@ -248,7 +248,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -261,7 +261,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -286,7 +286,7 @@ mod tests { let expected = "Union\ \n TableScan: test1\ \n TableScan: test4"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -311,7 +311,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -338,7 +338,7 @@ mod tests { let expected = "Union\ \n TableScan: test2\ \n TableScan: test3"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -351,7 +351,7 @@ mod tests { let plan = LogicalPlanBuilder::from(left).union(right)?.build()?; let expected = "TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -366,7 +366,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } #[test] @@ -399,6 +399,6 @@ mod tests { let expected = "Projection: a, b, c\ \n TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected) + assert_together_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 83db4b0640a49..812957e7041c1 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1048,8 +1048,9 @@ mod tests { }; use async_trait::async_trait; + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { crate::test::assert_optimized_plan_eq( Arc::new(PushDownFilter::new()), plan, @@ -1058,29 +1059,17 @@ mod tests { } fn assert_optimized_plan_eq_with_rewrite_predicate( - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![ Arc::new(RewriteDisjunctivePredicate::new()), Arc::new(PushDownFilter::new()), ]); - let mut optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); - optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.get(1).unwrap(), - &optimized_plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + let formatted_plan = format!("{optimized_plan:?}"); - assert_eq!(plan.schema(), optimized_plan.schema()); assert_eq!(expected, formatted_plan); Ok(()) } @@ -1096,7 +1085,7 @@ mod tests { let expected = "\ Projection: test.a, test.b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1113,7 +1102,7 @@ mod tests { \n Limit: skip=0, fetch=10\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1123,7 +1112,7 @@ mod tests { .filter(lit(0i64).eq(lit(1i64)))? .build()?; let expected = "TableScan: test, full_filters=[Int64(0) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1139,7 +1128,7 @@ mod tests { Projection: test.c, test.b\ \n Projection: test.a, test.b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1153,7 +1142,7 @@ mod tests { let expected = "\ Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS total_salary]]\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1166,7 +1155,7 @@ mod tests { let expected = "Filter: test.b > Int64(10)\ \n Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1178,7 +1167,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.b + test.a]], aggr=[[SUM(test.a), test.b]]\ \n TableScan: test, full_filters=[test.b + test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1193,7 +1182,7 @@ mod tests { Filter: b > Int64(10)\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b) AS b]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that a filter is pushed to before a projection, the filter expression is correctly re-written @@ -1208,7 +1197,7 @@ mod tests { let expected = "\ Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } fn add(left: Expr, right: Expr) -> Expr { @@ -1252,7 +1241,7 @@ mod tests { let expected = "\ Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[test.a * Int32(2) + test.c = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter is pushed to after 2 projections, the filter expression is correctly re-written @@ -1284,7 +1273,7 @@ mod tests { Projection: b * Int32(3) AS a, test.c\ \n Projection: test.a * Int32(2) + test.c AS b, test.c\ \n TableScan: test, full_filters=[(test.a * Int32(2) + test.c) * Int32(3) = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[derive(Debug, PartialEq, Eq, Hash)] @@ -1347,7 +1336,7 @@ mod tests { let expected = "\ NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1364,7 +1353,7 @@ mod tests { Filter: test.c = Int64(2)\ \n NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1381,7 +1370,7 @@ mod tests { NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_eq(plan, expected)?; let custom_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoopPlan { @@ -1399,7 +1388,7 @@ mod tests { \n NoopPlan\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed @@ -1432,7 +1421,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when a filter with two predicates is applied after an aggregation that only allows one to be pushed, one is pushed @@ -1466,7 +1455,7 @@ mod tests { \n Aggregate: groupBy=[[b]], aggr=[[SUM(test.c)]]\ \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that when two limits are in place, we jump neither @@ -1488,7 +1477,7 @@ mod tests { \n Limit: skip=0, fetch=20\ \n Projection: test.a, test.b\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1503,7 +1492,7 @@ mod tests { let expected = "Union\ \n TableScan: test, full_filters=[test.a = Int64(1)]\ \n TableScan: test2, full_filters=[test2.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1526,7 +1515,7 @@ mod tests { \n SubqueryAlias: test2\ \n Projection: test.a AS b\ \n TableScan: test, full_filters=[test.a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1557,7 +1546,7 @@ mod tests { \n Projection: test1.d, test1.e, test1.f\ \n TableScan: test1, full_filters=[test1.d > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -1583,7 +1572,7 @@ mod tests { \n TableScan: test, full_filters=[test.a = Int32(1)]\ \n Projection: test1.a, test1.b, test1.c\ \n TableScan: test1, full_filters=[test1.a > Int32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters with the same columns are correctly placed @@ -1617,7 +1606,7 @@ mod tests { \n Projection: test.a\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters to be placed on the same depth are ANDed @@ -1647,7 +1636,7 @@ mod tests { \n Limit: skip=0, fetch=1\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// verifies that filters on a plan with user nodes are not lost @@ -1673,7 +1662,7 @@ mod tests { TestUserDefined\ \n TableScan: test, full_filters=[test.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-on-join predicates on a column common to both sides is pushed to both sides @@ -1711,7 +1700,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-using-join predicates on a column common to both sides is pushed to both sides @@ -1748,7 +1737,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from both sides are converted to join filterss @@ -1790,7 +1779,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates with columns from one side of a join are pushed only to that side @@ -1832,7 +1821,7 @@ mod tests { \n TableScan: test, full_filters=[test.b <= Int64(1)]\ \n Projection: test2.a, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the right side of a left join are not duplicated @@ -1871,7 +1860,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-join predicates on the left side of a right join are not duplicated @@ -1909,7 +1898,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-left-join predicate on a column common to both sides is only pushed to the left side @@ -1947,7 +1936,7 @@ mod tests { \n TableScan: test, full_filters=[test.a <= Int64(1)]\ \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// post-right-join predicate on a column common to both sides is only pushed to the right side @@ -1985,7 +1974,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a\ \n TableScan: test2, full_filters=[test2.a <= Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to both inputs @@ -2028,7 +2017,7 @@ mod tests { \n TableScan: test, full_filters=[test.c > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// join filter should be completely removed after pushdown @@ -2070,7 +2059,7 @@ mod tests { \n TableScan: test, full_filters=[test.b > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2110,7 +2099,7 @@ mod tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to right input @@ -2153,7 +2142,7 @@ mod tests { \n TableScan: test\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2, full_filters=[test2.c > UInt32(4)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should be pushed to left input @@ -2196,7 +2185,7 @@ mod tests { \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.a, test2.b, test2.c\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// single table predicate parts of ON condition should not be pushed @@ -2234,7 +2223,7 @@ mod tests { ); let expected = &format!("{plan:?}"); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } struct PushDownProvider { @@ -2293,7 +2282,7 @@ mod tests { let expected = "\ TableScan: test, full_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2304,7 +2293,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test, partial_filters=[a = Int64(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2312,7 +2301,7 @@ mod tests { let plan = table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?; - let optimised_plan = PushDownFilter::new() + let optimized_plan = PushDownFilter::new() .try_optimize(&plan, &OptimizerContext::new()) .expect("failed to optimize plan") .unwrap(); @@ -2323,7 +2312,7 @@ mod tests { // Optimizing the same plan multiple times should produce the same plan // each time. - assert_optimized_plan_eq(&optimised_plan, expected) + assert_optimized_plan_eq(optimized_plan, expected) } #[test] @@ -2334,7 +2323,7 @@ mod tests { let expected = "\ Filter: a = Int64(1)\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2363,7 +2352,7 @@ mod tests { \n Filter: a = Int64(10) AND b > Int64(11)\ \n TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2394,7 +2383,7 @@ Projection: a, b "# .trim(); - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2422,7 +2411,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2454,7 +2443,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]\ "; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2479,7 +2468,7 @@ Projection: a, b Projection: test.a AS b, test.c AS d\ \n TableScan: test, full_filters=[test.a > Int64(10), test.c > Int64(10)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// predicate on join key in filter expression should be pushed down to both inputs @@ -2519,7 +2508,7 @@ Projection: a, b \n TableScan: test, full_filters=[test.a > UInt32(1)]\ \n Projection: test2.b AS d\ \n TableScan: test2, full_filters=[test2.b > UInt32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2548,7 +2537,7 @@ Projection: a, b Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2580,7 +2569,7 @@ Projection: a, b \n Projection: test.a AS b, test.c\ \n TableScan: test, full_filters=[test.a IN ([UInt32(1), UInt32(2), UInt32(3), UInt32(4)])]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2616,7 +2605,7 @@ Projection: a, b \n Subquery:\ \n Projection: sq.c\ \n TableScan: sq"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2649,7 +2638,7 @@ Projection: a, b \n Projection: Int64(0) AS a\ \n Filter: Int64(0) = Int64(1)\ \n EmptyRelation"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2677,14 +2666,14 @@ Projection: a, b \n TableScan: test, full_filters=[test.b > UInt32(1) OR test.c < UInt32(10)]\ \n Projection: test1.a AS d, test1.a AS e\ \n TableScan: test1"; - assert_optimized_plan_eq_with_rewrite_predicate(&plan, expected)?; + assert_optimized_plan_eq_with_rewrite_predicate(plan.clone(), expected)?; // Originally global state which can help to avoid duplicate Filters been generated and pushed down. // Now the global state is removed. Need to double confirm that avoid duplicate Filters. let optimized_plan = PushDownFilter::new() .try_optimize(&plan, &OptimizerContext::new())? .expect("failed to optimize plan"); - assert_optimized_plan_eq(&optimized_plan, expected) + assert_optimized_plan_eq(optimized_plan, expected) } #[test] @@ -2725,7 +2714,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2766,7 +2755,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2812,7 +2801,7 @@ Projection: a, b \n TableScan: test1\ \n Projection: test2.a, test2.b\ \n TableScan: test2, full_filters=[test2.b > UInt32(2)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2857,7 +2846,7 @@ Projection: a, b \n TableScan: test1, full_filters=[test1.b > UInt32(1)]\ \n Projection: test2.a, test2.b\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -2890,7 +2879,7 @@ Projection: a, b \n Projection: test1.a, SUM(test1.b), random() + Int32(1) AS r\ \n Aggregate: groupBy=[[test1.a]], aggr=[[SUM(test1.b)]]\ \n TableScan: test1, full_filters=[test1.a > Int32(5)]"; - assert_optimized_plan_eq(&plan, expected_after) + assert_optimized_plan_eq(plan, expected_after) } #[test] @@ -2932,6 +2921,6 @@ Projection: a, b \n Inner Join: test1.a = test2.a\ \n TableScan: test1\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index 33d02d5c5628e..da445c7f4cb4b 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -284,7 +284,7 @@ mod test { max, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) } @@ -303,7 +303,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -321,7 +321,7 @@ mod test { let expected = "Limit: skip=0, fetch=10\ \n TableScan: test, fetch=10"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -338,7 +338,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -358,7 +358,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -375,7 +375,7 @@ mod test { \n Sort: test.a, fetch=10\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -392,7 +392,7 @@ mod test { \n Sort: test.a, fetch=15\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -411,7 +411,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -426,7 +426,7 @@ mod test { let expected = "Limit: skip=10, fetch=None\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -444,7 +444,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -461,7 +461,7 @@ mod test { \n Limit: skip=10, fetch=990\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -478,7 +478,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -494,7 +494,7 @@ mod test { let expected = "Limit: skip=10, fetch=10\ \n TableScan: test, fetch=20"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -511,7 +511,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -531,7 +531,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -555,7 +555,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -579,7 +579,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -608,7 +608,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -637,7 +637,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_equal(&outer_query, expected) + assert_optimized_plan_equal(outer_query, expected) } #[test] @@ -663,7 +663,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -682,7 +682,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -701,7 +701,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -719,7 +719,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -737,7 +737,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -755,7 +755,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected)?; + assert_optimized_plan_equal(plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1) .join( @@ -773,7 +773,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -798,7 +798,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -823,7 +823,7 @@ mod test { \n TableScan: test, fetch=1010\ \n TableScan: test2"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -848,7 +848,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -873,7 +873,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test2, fetch=1010"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -893,7 +893,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -913,7 +913,7 @@ mod test { \n Limit: skip=0, fetch=2000\ \n TableScan: test2, fetch=2000"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -928,7 +928,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -943,7 +943,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -960,6 +960,6 @@ mod test { \n Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/push_down_projection.rs b/datafusion/optimizer/src/push_down_projection.rs index ccdcf2f65bc8f..5d9b48af3b292 100644 --- a/datafusion/optimizer/src/push_down_projection.rs +++ b/datafusion/optimizer/src/push_down_projection.rs @@ -27,7 +27,7 @@ mod tests { use crate::optimize_projections::OptimizeProjections; use crate::optimizer::Optimizer; use crate::test::*; - use crate::OptimizerContext; + use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::builder::table_scan_with_filters; @@ -51,7 +51,7 @@ mod tests { let expected = "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -65,7 +65,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -81,7 +81,7 @@ mod tests { \n SubqueryAlias: a\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -98,7 +98,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n TableScan: test projection=[b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -123,7 +123,7 @@ mod tests { Aggregate: groupBy=[[]], aggr=[[MAX(m4.tag.one) AS tag.one]]\ \n TableScan: m4 projection=[tag.one]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -137,7 +137,7 @@ mod tests { let expected = "Projection: test.a, test.c, test.b\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -147,7 +147,7 @@ mod tests { let plan = table_scan(Some("test"), &schema, Some(vec![1, 0, 2]))?.build()?; let expected = "TableScan: test projection=[b, a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -160,7 +160,7 @@ mod tests { let expected = "Projection: test.a, test.b\ \n TableScan: test projection=[b, a]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -173,7 +173,7 @@ mod tests { let expected = "Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -195,7 +195,7 @@ mod tests { \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.b, test.a\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -215,7 +215,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -267,7 +267,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[c1]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -317,7 +317,7 @@ mod tests { \n TableScan: test projection=[a, b]\ \n TableScan: test2 projection=[a]"; - let optimized_plan = optimize(&plan)?; + let optimized_plan = optimize(plan)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -361,7 +361,7 @@ mod tests { let expected = "Projection: CAST(test.c AS Float64)\ \n TableScan: test projection=[c]"; - assert_optimized_plan_eq(&projection, expected) + assert_optimized_plan_eq(projection, expected) } #[test] @@ -377,7 +377,7 @@ mod tests { let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -398,7 +398,7 @@ mod tests { let expected = "TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -418,7 +418,7 @@ mod tests { \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -427,7 +427,7 @@ mod tests { let plan = LogicalPlanBuilder::from(table_scan).build()?; // should expand projection to all columns without projection let expected = "TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -438,7 +438,7 @@ mod tests { .build()?; let expected = "Projection: Int64(1), Int64(2)\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes unused columns in projections @@ -457,14 +457,14 @@ mod tests { assert_fields_eq(&plan, vec!["c", "MAX(test.a)"]); - let plan = optimize(&plan).expect("failed to optimize plan"); + let plan = optimize(plan).expect("failed to optimize plan"); let expected = "\ Aggregate: groupBy=[[test.c]], aggr=[[MAX(test.a)]]\ \n Filter: test.c > Int32(1)\ \n Projection: test.c, test.a\ \n TableScan: test projection=[a, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that it removes un-needed projections @@ -486,7 +486,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -515,7 +515,7 @@ mod tests { Projection: Int32(1) AS a\ \n TableScan: test projection=[], full_filters=[b = Int32(1)]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } /// tests that optimizing twice yields same plan @@ -528,9 +528,9 @@ mod tests { .project(vec![lit(1).alias("a")])? .build()?; - let optimized_plan1 = optimize(&plan).expect("failed to optimize plan"); + let optimized_plan1 = optimize(plan).expect("failed to optimize plan"); let optimized_plan2 = - optimize(&optimized_plan1).expect("failed to optimize plan"); + optimize(optimized_plan1.clone()).expect("failed to optimize plan"); let formatted_plan1 = format!("{optimized_plan1:?}"); let formatted_plan2 = format!("{optimized_plan2:?}"); @@ -559,7 +559,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.c]], aggr=[[MAX(test.b)]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -585,7 +585,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(test.b), COUNT(test.b) FILTER (WHERE test.c > Int32(42)) AS count2]]\ \n TableScan: test projection=[a, b, c]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -602,7 +602,7 @@ mod tests { \n Distinct:\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } #[test] @@ -641,25 +641,23 @@ mod tests { \n WindowAggr: windowExpr=[[MAX(test.a) PARTITION BY [test.b] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: test projection=[a, b]"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_eq(plan, expected) } - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> { let optimized_plan = optimize(plan).expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); Ok(()) } - fn optimize(plan: &LogicalPlan) -> Result { + fn optimize(plan: LogicalPlan) -> Result { let optimizer = Optimizer::with_rules(vec![Arc::new(OptimizeProjections::new())]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let optimized_plan = + optimizer.optimize(plan, &OptimizerContext::new(), observe)?; + Ok(optimized_plan) } + + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} } diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 0055e329c29d9..d64f8506c416d 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -171,7 +171,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } @@ -194,7 +194,7 @@ mod tests { assert_optimized_plan_eq( Arc::new(ReplaceDistinctWithAggregate::new()), - &plan, + plan, expected, ) } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 8acc36e479cab..85dcec63ab589 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -427,7 +427,7 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -483,7 +483,7 @@ mod tests { \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -521,7 +521,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -557,7 +557,7 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -591,7 +591,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -730,7 +730,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -796,7 +796,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -835,7 +835,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -875,7 +875,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -908,7 +908,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -940,7 +940,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -971,7 +971,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -1028,7 +1028,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) @@ -1077,7 +1077,7 @@ mod tests { assert_multi_rules_optimized_plan_eq_display_indent( vec![Arc::new(ScalarSubqueryToJoin::new())], - &plan, + plan, expected, ); Ok(()) diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 5b47abb308d0d..5e3dfe34a50c0 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -313,7 +313,7 @@ mod tests { min, sum, AggregateFunction, }; - fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(SingleDistinctToGroupBy::new()), plan, @@ -335,7 +335,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]] [MAX(test.b):UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -352,7 +352,7 @@ mod tests { \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -373,7 +373,7 @@ mod tests { let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -391,7 +391,7 @@ mod tests { let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -410,7 +410,7 @@ mod tests { let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -426,7 +426,7 @@ mod tests { \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -443,7 +443,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -461,7 +461,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(DISTINCT test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -490,7 +490,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -508,7 +508,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -525,7 +525,7 @@ mod tests { \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -555,7 +555,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -574,7 +574,7 @@ mod tests { \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -593,7 +593,7 @@ mod tests { \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -616,7 +616,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -639,7 +639,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -662,7 +662,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -685,7 +685,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } #[test] @@ -708,6 +708,6 @@ mod tests { let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) + assert_optimized_plan_equal(plan, expected) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index e691fe9a53516..b8e9c66bc2b45 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -16,7 +16,7 @@ // under the License. use crate::analyzer::{Analyzer, AnalyzerRule}; -use crate::optimizer::{assert_schema_is_the_same, Optimizer}; +use crate::optimizer::Optimizer; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::config::ConfigOptions; @@ -152,20 +152,16 @@ pub fn assert_analyzer_check_err( } pub fn assert_optimized_plan_eq( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { - let optimizer = Optimizer::with_rules(vec![rule.clone()]); - let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + + // in tests we are applying only one rule once + let opt_context = OptimizerContext::new().with_max_passes(1); - // Ensure schemas always match after an optimization - assert_schema_is_the_same(rule.name(), plan, &optimized_plan)?; + let optimizer = Optimizer::with_rules(vec![rule.clone()]); + let optimized_plan = optimizer.optimize(plan, &opt_context, observe)?; let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); @@ -174,7 +170,7 @@ pub fn assert_optimized_plan_eq( pub fn assert_optimized_plan_eq_with_rules( rules: Vec>, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) -> Result<()> { fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} @@ -187,58 +183,46 @@ pub fn assert_optimized_plan_eq_with_rules( .expect("failed to optimize plan"); let formatted_plan = format!("{optimized_plan:?}"); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); Ok(()) } +fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + pub fn assert_optimized_plan_eq_display_indent( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(vec![rule]); let optimized_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - ) - .expect("failed to optimize plan") - .unwrap_or_else(|| plan.clone()); + .optimize(plan, &OptimizerContext::new(), observe) + .expect("failed to optimize plan"); let formatted_plan = optimized_plan.display_indent_schema().to_string(); assert_eq!(formatted_plan, expected); } pub fn assert_multi_rules_optimized_plan_eq_display_indent( rules: Vec>, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(rules); - let mut optimized_plan = plan.clone(); - for rule in &optimizer.rules { - optimized_plan = optimizer - .optimize_recursively(rule, &optimized_plan, &OptimizerContext::new()) - .expect("failed to optimize plan") - .unwrap_or_else(|| optimized_plan.clone()); - } + let optimized_plan = optimizer + .optimize(plan, &OptimizerContext::new(), observe) + .expect("failed to optimize plan"); let formatted_plan = optimized_plan.display_indent_schema().to_string(); assert_eq!(formatted_plan, expected); } pub fn assert_optimizer_err( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, expected: &str, ) { let optimizer = Optimizer::with_rules(vec![rule]); - let res = optimizer.optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - ); + let res = optimizer.optimize(plan, &OptimizerContext::new(), observe); match res { - Ok(plan) => assert_eq!(format!("{}", plan.unwrap().display_indent()), "An error"), + Ok(plan) => assert_eq!(format!("{}", plan.display_indent()), "An error"), Err(ref e) => { let actual = format!("{e}"); if expected.is_empty() || !actual.contains(expected) { @@ -250,16 +234,11 @@ pub fn assert_optimizer_err( pub fn assert_optimization_skipped( rule: Arc, - plan: &LogicalPlan, + plan: LogicalPlan, ) -> Result<()> { let optimizer = Optimizer::with_rules(vec![rule]); - let new_plan = optimizer - .optimize_recursively( - optimizer.rules.first().unwrap(), - plan, - &OptimizerContext::new(), - )? - .unwrap_or_else(|| plan.clone()); + let new_plan = optimizer.optimize(plan.clone(), &OptimizerContext::new(), observe)?; + assert_eq!( format!("{}", plan.display_indent()), format!("{}", new_plan.display_indent()) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index acafc0bafaf46..c28349447dbb3 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -315,7 +315,7 @@ fn test_sql(sql: &str) -> Result { let optimizer = Optimizer::new(); // analyze and optimize the logical plan let plan = analyzer.execute_and_check(&plan, config.options(), |_, _| {})?; - optimizer.optimize(&plan, &config, |_, _| {}) + optimizer.optimize(plan, &config, |_, _| {}) } #[derive(Default)] diff --git a/datafusion/sqllogictest/test_files/join.slt b/datafusion/sqllogictest/test_files/join.slt index da9b4168e7e09..135ab80754253 100644 --- a/datafusion/sqllogictest/test_files/join.slt +++ b/datafusion/sqllogictest/test_files/join.slt @@ -587,7 +587,7 @@ FROM t1 ---- 11 11 11 -# subsequent inner join +# subsequent inner join query III rowsort SELECT t1.t1_id, t2.t2_id, t3.t3_id FROM t1