From 5233e4a3286b5d6393eaf4e5f86c4b6fc1791003 Mon Sep 17 00:00:00 2001 From: Jax Liu Date: Mon, 23 Dec 2024 19:56:23 +0800 Subject: [PATCH] apply the first unparsing result only --- datafusion-examples/examples/plan_to_sql.rs | 13 +++-- .../sql/src/unparser/extension_unparser.rs | 18 ++++-- datafusion/sql/src/unparser/plan.rs | 30 ++++++++-- datafusion/sql/tests/cases/plan_to_sql.rs | 57 ++++++++++++++----- 4 files changed, 89 insertions(+), 29 deletions(-) diff --git a/datafusion-examples/examples/plan_to_sql.rs b/datafusion-examples/examples/plan_to_sql.rs index 27ac7d50c9b6..2e176199d74c 100644 --- a/datafusion-examples/examples/plan_to_sql.rs +++ b/datafusion-examples/examples/plan_to_sql.rs @@ -28,6 +28,7 @@ use datafusion_sql::unparser::ast::{ DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, }; use datafusion_sql::unparser::dialect::CustomDialectBuilder; +use datafusion_sql::unparser::extension_unparser::UnparseResult; use datafusion_sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser; use datafusion_sql::unparser::{plan_to_sql, Unparser}; use std::fmt; @@ -213,12 +214,12 @@ impl UserDefinedLogicalNodeUnparser for PlanToStatement { &self, node: &dyn UserDefinedLogicalNode, unparser: &Unparser, - ) -> Result> { + ) -> Result { if let Some(plan) = node.as_any().downcast_ref::() { let input = unparser.plan_to_sql(&plan.input)?; - Ok(Some(input)) + Ok(UnparseResult::Statement(input)) } else { - Ok(None) + Ok(UnparseResult::Original) } } } @@ -260,10 +261,10 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery { _query: &mut Option<&mut QueryBuilder>, _select: &mut Option<&mut SelectBuilder>, relation: &mut Option<&mut RelationBuilder>, - ) -> Result<()> { + ) -> Result { if let Some(plan) = node.as_any().downcast_ref::() { let Statement::Query(input) = unparser.plan_to_sql(&plan.input)? else { - return Ok(()); + return Ok(UnparseResult::Original); }; let mut derived_builder = DerivedRelationBuilder::default(); derived_builder.subquery(input); @@ -272,7 +273,7 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery { rel.derived(derived_builder); } } - Ok(()) + Ok(UnparseResult::WithinStatement) } } diff --git a/datafusion/sql/src/unparser/extension_unparser.rs b/datafusion/sql/src/unparser/extension_unparser.rs index 250df5fe248a..32643260cb3c 100644 --- a/datafusion/sql/src/unparser/extension_unparser.rs +++ b/datafusion/sql/src/unparser/extension_unparser.rs @@ -33,8 +33,8 @@ pub trait UserDefinedLogicalNodeUnparser { _query: &mut Option<&mut QueryBuilder>, _select: &mut Option<&mut SelectBuilder>, _relation: &mut Option<&mut RelationBuilder>, - ) -> datafusion_common::Result<()> { - Ok(()) + ) -> datafusion_common::Result { + Ok(UnparseResult::Original) } /// Unparse the custom logical node to a statement. @@ -44,7 +44,17 @@ pub trait UserDefinedLogicalNodeUnparser { &self, _node: &dyn UserDefinedLogicalNode, _unparser: &Unparser, - ) -> datafusion_common::Result> { - Ok(None) + ) -> datafusion_common::Result { + Ok(UnparseResult::Original) } } + +/// The result of unparsing a custom logical node. +pub enum UnparseResult { + /// If the custom logical node was successfully unparsed and return a statement. + Statement(Statement), + /// If the custom logical node was successfully unparsed within a statement. + WithinStatement, + /// If the custom logical node wasn't unparsed. + Original, +} diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 80126cd25477..d282100a3a0f 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -33,6 +33,7 @@ use super::{ Unparser, }; use crate::unparser::ast::UnnestRelationBuilder; +use crate::unparser::extension_unparser::UnparseResult; use crate::unparser::utils::unproject_agg_exprs; use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ @@ -126,14 +127,25 @@ impl Unparser<'_> { /// Try to unparse a [UserDefinedLogicalNode] to a SQL statement. /// If multiple unparsers are registered for the same [UserDefinedLogicalNode], - /// the last unparsing result will be returned. + /// the first unparsing result will be returned. fn extension_to_statement( &self, node: &dyn UserDefinedLogicalNode, ) -> Result { let mut statement = None; for unparser in &self.extension_unparsers { - statement = unparser.unparse_to_statement(node, self)?; + match unparser.unparse_to_statement(node, self)? { + UnparseResult::Statement(stmt) => { + statement = Some(stmt); + break; + } + UnparseResult::WithinStatement => { + return not_impl_err!( + "UnparseResult::WithinStatement is not supported for `extension_to_statement`" + ); + } + UnparseResult::Original => {} + } } if let Some(statement) = statement { Ok(statement) @@ -144,7 +156,7 @@ impl Unparser<'_> { /// Try to unparse a [UserDefinedLogicalNode] to a SQL statement. /// If multiple unparsers are registered for the same [UserDefinedLogicalNode], - /// all of them will be called in order. + /// the first unparser supporting the node will be used. fn extension_to_sql( &self, node: &dyn UserDefinedLogicalNode, @@ -153,9 +165,17 @@ impl Unparser<'_> { relation: &mut Option<&mut RelationBuilder>, ) -> Result<()> { for unparser in &self.extension_unparsers { - unparser.unparse(node, self, query, select, relation)?; + match unparser.unparse(node, self, query, select, relation)? { + UnparseResult::WithinStatement => return Ok(()), + UnparseResult::Original => {} + UnparseResult::Statement(_) => { + return not_impl_err!( + "UnparseResult::Statement is not supported for `extension_to_sql`" + ); + } + } } - Ok(()) + not_impl_err!("Unsupported extension node: {node:?}") } fn select_to_sql_statement(&self, plan: &LogicalPlan) -> Result { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index 172aca3ea05b..d3f6397fef51 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -50,7 +50,9 @@ use datafusion_functions_nested::planner::{FieldAccessPlanner, NestedFunctionPla use datafusion_sql::unparser::ast::{ DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, }; -use datafusion_sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser; +use datafusion_sql::unparser::extension_unparser::{ + UnparseResult, UserDefinedLogicalNodeUnparser, +}; use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect}; use sqlparser::parser::Parser; @@ -1459,16 +1461,39 @@ impl UserDefinedLogicalNodeUnparser for MockStatementUnparser { &self, node: &dyn UserDefinedLogicalNode, unparser: &Unparser, - ) -> Result> { + ) -> Result { if let Some(plan) = node.as_any().downcast_ref::() { let input = unparser.plan_to_sql(&plan.input)?; - Ok(Some(input)) + Ok(UnparseResult::Statement(input)) } else { - Ok(None) + Ok(UnparseResult::Original) } } } +struct UnusedUnparser {} + +impl UserDefinedLogicalNodeUnparser for UnusedUnparser { + fn unparse( + &self, + _node: &dyn UserDefinedLogicalNode, + _unparser: &Unparser, + _query: &mut Option<&mut QueryBuilder>, + _select: &mut Option<&mut SelectBuilder>, + _relation: &mut Option<&mut RelationBuilder>, + ) -> Result { + panic!("This should not be called"); + } + + fn unparse_to_statement( + &self, + _node: &dyn UserDefinedLogicalNode, + _unparser: &Unparser, + ) -> Result { + panic!("This should not be called"); + } +} + #[test] fn test_unparse_extension_to_statement() -> Result<()> { let dialect = GenericDialect {}; @@ -1484,8 +1509,10 @@ fn test_unparse_extension_to_statement() -> Result<()> { let extension = LogicalPlan::Extension(Extension { node: Arc::new(extension), }); - let unparser = - Unparser::default().with_extension_unparsers(vec![Arc::new(MockStatementUnparser {})]); + let unparser = Unparser::default().with_extension_unparsers(vec![ + Arc::new(MockStatementUnparser {}), + Arc::new(UnusedUnparser {}), + ]); let sql = unparser.plan_to_sql(&extension)?; let expected = "SELECT * FROM j1"; assert_eq!(sql.to_string(), expected); @@ -1510,10 +1537,10 @@ impl UserDefinedLogicalNodeUnparser for MockSqlUnparser { _query: &mut Option<&mut QueryBuilder>, _select: &mut Option<&mut SelectBuilder>, relation: &mut Option<&mut RelationBuilder>, - ) -> Result<()> { + ) -> Result { if let Some(plan) = node.as_any().downcast_ref::() { let Statement::Query(input) = unparser.plan_to_sql(&plan.input)? else { - return Ok(()); + return Ok(UnparseResult::Original); }; let mut derived_builder = DerivedRelationBuilder::default(); derived_builder.subquery(input); @@ -1522,7 +1549,7 @@ impl UserDefinedLogicalNodeUnparser for MockSqlUnparser { rel.derived(derived_builder); } } - Ok(()) + Ok(UnparseResult::WithinStatement) } } @@ -1545,17 +1572,19 @@ fn test_unparse_extension_to_sql() -> Result<()> { let plan = LogicalPlanBuilder::from(extension) .project(vec![col("j1_id").alias("user_id")])? .build()?; - let unparser = - Unparser::default().with_extension_unparsers(vec![Arc::new(MockSqlUnparser {})]); + let unparser = Unparser::default().with_extension_unparsers(vec![ + Arc::new(MockSqlUnparser {}), + Arc::new(UnusedUnparser {}), + ]); let sql = unparser.plan_to_sql(&plan)?; let expected = "SELECT j1.j1_id AS user_id FROM (SELECT * FROM j1)"; assert_eq!(sql.to_string(), expected); if let Some(err) = plan_to_sql(&plan).err() { - assert_eq!( + assert_contains!( err.to_string(), - "External error: `relation` must be initialized" - ) + "This feature is not implemented: Unsupported extension node: MockUserDefinedLogicalPlan" + ); } else { panic!("Expected error") }