From 2b49890545a28b84086731146a5c58431368fbbc Mon Sep 17 00:00:00 2001 From: Jia-Xuan Liu Date: Thu, 14 Nov 2024 16:50:59 +0800 Subject: [PATCH] apply timestamp simplify rule before type coercion --- .../optimize/simplify_timestamp.rs | 43 +++++----------- wren-core/core/src/mdl/context.rs | 4 +- wren-core/core/src/mdl/mod.rs | 50 +++++++++++++++++-- 3 files changed, 62 insertions(+), 35 deletions(-) diff --git a/wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs b/wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs index d837f64cf..b29c2f517 100644 --- a/wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs +++ b/wren-core/core/src/logical_plan/optimize/simplify_timestamp.rs @@ -17,19 +17,21 @@ * under the License. */ use datafusion::arrow::datatypes::{DataType, TimeUnit}; -use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRewriter, +}; use datafusion::common::ScalarValue::{ TimestampMicrosecond, TimestampMillisecond, TimestampSecond, }; use datafusion::common::{DFSchema, DFSchemaRef, Result, ScalarValue}; +use datafusion::config::ConfigOptions; use datafusion::execution::context::ExecutionProps; use datafusion::logical_expr::expr_rewriter::NamePreserver; use datafusion::logical_expr::simplify::SimplifyContext; use datafusion::logical_expr::utils::merge_schema; use datafusion::logical_expr::{cast, Cast, LogicalPlan, TryCast}; -use datafusion::optimizer::optimizer::ApplyOrder; use datafusion::optimizer::simplify_expressions::ExprSimplifier; -use datafusion::optimizer::{OptimizerConfig, OptimizerRule}; +use datafusion::optimizer::AnalyzerRule; use datafusion::prelude::Expr; use datafusion::scalar::ScalarValue::TimestampNanosecond; use std::sync::Arc; @@ -46,37 +48,18 @@ impl TimestampSimplify { } } -impl OptimizerRule for TimestampSimplify { - fn name(&self) -> &str { - "simplify_cast_expressions" - } - - fn apply_order(&self) -> Option { - Some(ApplyOrder::BottomUp) +impl AnalyzerRule for TimestampSimplify { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + Self::analyze_internal(plan).data() } - fn supports_rewrite(&self) -> bool { - true - } - - /// if supports_owned returns true, the Optimizer calls - /// [`Self::rewrite`] instead of [`Self::try_optimize`] - fn rewrite( - &self, - plan: LogicalPlan, - config: &dyn OptimizerConfig, - ) -> Result> { - let mut execution_props = ExecutionProps::new(); - execution_props.query_execution_start_time = config.query_execution_start_time(); - Self::optimize_internal(plan, &execution_props) + fn name(&self) -> &str { + "simplify_timestamp_expressions" } } impl TimestampSimplify { - fn optimize_internal( - plan: LogicalPlan, - execution_props: &ExecutionProps, - ) -> Result> { + fn analyze_internal(plan: LogicalPlan) -> Result> { let schema = if !plan.inputs().is_empty() { DFSchemaRef::new(merge_schema(&plan.inputs())) } else if let LogicalPlan::TableScan(scan) = &plan { @@ -97,8 +80,8 @@ impl TimestampSimplify { } else { Arc::new(DFSchema::empty()) }; - - let info = SimplifyContext::new(execution_props).with_schema(schema); + let execution_props = ExecutionProps::default(); + let info = SimplifyContext::new(&execution_props).with_schema(schema); // Inputs have already been rewritten (due to bottom-up traversal handled by Optimizer) // Just need to rewrite our own expressions diff --git a/wren-core/core/src/mdl/context.rs b/wren-core/core/src/mdl/context.rs index 5b02241e2..7b4f30305 100644 --- a/wren-core/core/src/mdl/context.rs +++ b/wren-core/core/src/mdl/context.rs @@ -140,6 +140,9 @@ fn analyze_rule_for_unparsing( Arc::new(InlineTableScan::new()), // Every rule that will generate [Expr::Wildcard] should be placed in front of [ExpandWildcardRule]. Arc::new(ExpandWildcardRule::new()), + // TimestampSimplify should be placed before TypeCoercion because the simplified timestamp should + // be casted to the target type if needed + Arc::new(TimestampSimplify::new()), // [Expr::Wildcard] should be expanded before [TypeCoercion] Arc::new(TypeCoercion::new()), // Disable it to avoid generate the alias name, `count(*)` because BigQuery doesn't allow @@ -180,7 +183,6 @@ fn optimize_rule_for_unparsing() -> Vec> { Arc::new(SingleDistinctToGroupBy::new()), // Disable SimplifyExpressions to avoid apply some function locally // Arc::new(SimplifyExpressions::new()), - Arc::new(TimestampSimplify::new()), Arc::new(UnwrapCastInComparison::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateGroupByConstant::new()), diff --git a/wren-core/core/src/mdl/mod.rs b/wren-core/core/src/mdl/mod.rs index 87d82bf9f..a787c4257 100644 --- a/wren-core/core/src/mdl/mod.rs +++ b/wren-core/core/src/mdl/mod.rs @@ -967,7 +967,7 @@ mod test { .build(), ) .column( - ColumnBuilder::new("cast_timestamp", "timestamp") + ColumnBuilder::new("cast_timestamptz", "timestamptz") .expression(r#"cast("出道時間" as timestamp with time zone)"#) .build(), ) @@ -976,7 +976,7 @@ mod test { .build(); let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze(manifest)?); - let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamp as timestamp) > timestamp '2011-01-01 21:00:00'"#; + let sql = r#"select count(*) from wren.test.artist where cast(cast_timestamptz as timestamp) > timestamp '2011-01-01 21:00:00'"#; let actual = transform_sql_with_ctx( &SessionContext::new(), Arc::clone(&analyzed_mdl), @@ -985,8 +985,8 @@ mod test { ) .await?; assert_eq!(actual, - "SELECT count(*) FROM (SELECT artist.cast_timestamp FROM (SELECT CAST(artist.\"出道時間\" AS TIMESTAMP WITH TIME ZONE) AS cast_timestamp \ - FROM artist) AS artist) AS artist WHERE artist.cast_timestamp > CAST('2011-01-01 21:00:00' AS TIMESTAMP)"); + "SELECT count(*) FROM (SELECT artist.cast_timestamptz FROM (SELECT CAST(artist.\"出道時間\" AS TIMESTAMP WITH TIME ZONE) AS cast_timestamptz \ + FROM artist) AS artist) AS artist WHERE CAST(artist.cast_timestamptz AS TIMESTAMP) > CAST('2011-01-01 21:00:00' AS TIMESTAMP)"); Ok(()) } @@ -1071,6 +1071,48 @@ mod test { (SELECT timestamp_table.timestamp_col, timestamp_table.timestamptz_col FROM \ (SELECT timestamp_table.timestamp_col AS timestamp_col, timestamp_table.timestamptz_col AS timestamptz_col \ FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table"); + + let sql = r#"select timestamptz_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + // assert the simplified literal will be casted to the timestamp tz + assert_eq!(actual, + "SELECT timestamp_table.timestamptz_col > CAST(CAST('2011-01-01 18:00:00' AS TIMESTAMP) AS TIMESTAMP WITH TIME ZONE) \ + FROM (SELECT timestamp_table.timestamptz_col FROM (SELECT timestamp_table.timestamptz_col AS timestamptz_col \ + FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table"); + + let sql = r#"select timestamptz_col > '2011-01-01 18:00:00' from wren.test.timestamp_table"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + // assert the string literal will be casted to the timestamp tz + assert_eq!(actual, + "SELECT timestamp_table.timestamptz_col > CAST('2011-01-01 18:00:00' AS TIMESTAMP WITH TIME ZONE) \ + FROM (SELECT timestamp_table.timestamptz_col FROM (SELECT timestamp_table.timestamptz_col AS timestamptz_col \ + FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table"); + + let sql = r#"select timestamp_col > cast('2011-01-01 18:00:00' as TIMESTAMP WITH TIME ZONE) from wren.test.timestamp_table"#; + let actual = transform_sql_with_ctx( + &SessionContext::new(), + Arc::clone(&analyzed_mdl), + &[], + sql, + ) + .await?; + // assert the simplified literal won't be casted to the timestamp tz + assert_eq!(actual, + "SELECT timestamp_table.timestamp_col > CAST('2011-01-01 18:00:00' AS TIMESTAMP) FROM \ + (SELECT timestamp_table.timestamp_col FROM (SELECT timestamp_table.timestamp_col AS timestamp_col \ + FROM datafusion.public.timestamp_table) AS timestamp_table) AS timestamp_table"); } Ok(()) }