Skip to content

Commit

Permalink
apply the first unparsing result only
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal committed Dec 23, 2024
1 parent abc23f0 commit 5233e4a
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 29 deletions.
13 changes: 7 additions & 6 deletions datafusion-examples/examples/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use datafusion_sql::unparser::ast::{
DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder,
};
use datafusion_sql::unparser::dialect::CustomDialectBuilder;
use datafusion_sql::unparser::extension_unparser::UnparseResult;
use datafusion_sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser;
use datafusion_sql::unparser::{plan_to_sql, Unparser};
use std::fmt;
Expand Down Expand Up @@ -213,12 +214,12 @@ impl UserDefinedLogicalNodeUnparser for PlanToStatement {
&self,
node: &dyn UserDefinedLogicalNode,
unparser: &Unparser,
) -> Result<Option<Statement>> {
) -> Result<UnparseResult> {
if let Some(plan) = node.as_any().downcast_ref::<MyLogicalPlan>() {
let input = unparser.plan_to_sql(&plan.input)?;
Ok(Some(input))
Ok(UnparseResult::Statement(input))
} else {
Ok(None)
Ok(UnparseResult::Original)
}
}
}
Expand Down Expand Up @@ -260,10 +261,10 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery {
_query: &mut Option<&mut QueryBuilder>,
_select: &mut Option<&mut SelectBuilder>,
relation: &mut Option<&mut RelationBuilder>,
) -> Result<()> {
) -> Result<UnparseResult> {
if let Some(plan) = node.as_any().downcast_ref::<MyLogicalPlan>() {
let Statement::Query(input) = unparser.plan_to_sql(&plan.input)? else {
return Ok(());
return Ok(UnparseResult::Original);
};
let mut derived_builder = DerivedRelationBuilder::default();
derived_builder.subquery(input);
Expand All @@ -272,7 +273,7 @@ impl UserDefinedLogicalNodeUnparser for PlanToSubquery {
rel.derived(derived_builder);
}
}
Ok(())
Ok(UnparseResult::WithinStatement)
}
}

Expand Down
18 changes: 14 additions & 4 deletions datafusion/sql/src/unparser/extension_unparser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ pub trait UserDefinedLogicalNodeUnparser {
_query: &mut Option<&mut QueryBuilder>,
_select: &mut Option<&mut SelectBuilder>,
_relation: &mut Option<&mut RelationBuilder>,
) -> datafusion_common::Result<()> {
Ok(())
) -> datafusion_common::Result<UnparseResult> {
Ok(UnparseResult::Original)
}

/// Unparse the custom logical node to a statement.
Expand All @@ -44,7 +44,17 @@ pub trait UserDefinedLogicalNodeUnparser {
&self,
_node: &dyn UserDefinedLogicalNode,
_unparser: &Unparser,
) -> datafusion_common::Result<Option<Statement>> {
Ok(None)
) -> datafusion_common::Result<UnparseResult> {
Ok(UnparseResult::Original)
}
}

/// The result of unparsing a custom logical node.
pub enum UnparseResult {
/// If the custom logical node was successfully unparsed and return a statement.
Statement(Statement),
/// If the custom logical node was successfully unparsed within a statement.
WithinStatement,
/// If the custom logical node wasn't unparsed.
Original,
}
30 changes: 25 additions & 5 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use super::{
Unparser,
};
use crate::unparser::ast::UnnestRelationBuilder;
use crate::unparser::extension_unparser::UnparseResult;
use crate::unparser::utils::unproject_agg_exprs;
use crate::utils::UNNEST_PLACEHOLDER;
use datafusion_common::{
Expand Down Expand Up @@ -126,14 +127,25 @@ impl Unparser<'_> {

/// Try to unparse a [UserDefinedLogicalNode] to a SQL statement.
/// If multiple unparsers are registered for the same [UserDefinedLogicalNode],
/// the last unparsing result will be returned.
/// the first unparsing result will be returned.
fn extension_to_statement(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<ast::Statement> {
let mut statement = None;
for unparser in &self.extension_unparsers {
statement = unparser.unparse_to_statement(node, self)?;
match unparser.unparse_to_statement(node, self)? {
UnparseResult::Statement(stmt) => {
statement = Some(stmt);
break;
}
UnparseResult::WithinStatement => {
return not_impl_err!(
"UnparseResult::WithinStatement is not supported for `extension_to_statement`"
);
}
UnparseResult::Original => {}
}
}
if let Some(statement) = statement {
Ok(statement)
Expand All @@ -144,7 +156,7 @@ impl Unparser<'_> {

/// Try to unparse a [UserDefinedLogicalNode] to a SQL statement.
/// If multiple unparsers are registered for the same [UserDefinedLogicalNode],
/// all of them will be called in order.
/// the first unparser supporting the node will be used.
fn extension_to_sql(
&self,
node: &dyn UserDefinedLogicalNode,
Expand All @@ -153,9 +165,17 @@ impl Unparser<'_> {
relation: &mut Option<&mut RelationBuilder>,
) -> Result<()> {
for unparser in &self.extension_unparsers {
unparser.unparse(node, self, query, select, relation)?;
match unparser.unparse(node, self, query, select, relation)? {
UnparseResult::WithinStatement => return Ok(()),
UnparseResult::Original => {}
UnparseResult::Statement(_) => {
return not_impl_err!(
"UnparseResult::Statement is not supported for `extension_to_sql`"
);
}
}
}
Ok(())
not_impl_err!("Unsupported extension node: {node:?}")
}

fn select_to_sql_statement(&self, plan: &LogicalPlan) -> Result<ast::Statement> {
Expand Down
57 changes: 43 additions & 14 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ use datafusion_functions_nested::planner::{FieldAccessPlanner, NestedFunctionPla
use datafusion_sql::unparser::ast::{
DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder,
};
use datafusion_sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser;
use datafusion_sql::unparser::extension_unparser::{
UnparseResult, UserDefinedLogicalNodeUnparser,
};
use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
use sqlparser::parser::Parser;

Expand Down Expand Up @@ -1459,16 +1461,39 @@ impl UserDefinedLogicalNodeUnparser for MockStatementUnparser {
&self,
node: &dyn UserDefinedLogicalNode,
unparser: &Unparser,
) -> Result<Option<Statement>> {
) -> Result<UnparseResult> {
if let Some(plan) = node.as_any().downcast_ref::<MockUserDefinedLogicalPlan>() {
let input = unparser.plan_to_sql(&plan.input)?;
Ok(Some(input))
Ok(UnparseResult::Statement(input))
} else {
Ok(None)
Ok(UnparseResult::Original)
}
}
}

struct UnusedUnparser {}

impl UserDefinedLogicalNodeUnparser for UnusedUnparser {
fn unparse(
&self,
_node: &dyn UserDefinedLogicalNode,
_unparser: &Unparser,
_query: &mut Option<&mut QueryBuilder>,
_select: &mut Option<&mut SelectBuilder>,
_relation: &mut Option<&mut RelationBuilder>,
) -> Result<UnparseResult> {
panic!("This should not be called");
}

fn unparse_to_statement(
&self,
_node: &dyn UserDefinedLogicalNode,
_unparser: &Unparser,
) -> Result<UnparseResult> {
panic!("This should not be called");
}
}

#[test]
fn test_unparse_extension_to_statement() -> Result<()> {
let dialect = GenericDialect {};
Expand All @@ -1484,8 +1509,10 @@ fn test_unparse_extension_to_statement() -> Result<()> {
let extension = LogicalPlan::Extension(Extension {
node: Arc::new(extension),
});
let unparser =
Unparser::default().with_extension_unparsers(vec![Arc::new(MockStatementUnparser {})]);
let unparser = Unparser::default().with_extension_unparsers(vec![
Arc::new(MockStatementUnparser {}),
Arc::new(UnusedUnparser {}),
]);
let sql = unparser.plan_to_sql(&extension)?;
let expected = "SELECT * FROM j1";
assert_eq!(sql.to_string(), expected);
Expand All @@ -1510,10 +1537,10 @@ impl UserDefinedLogicalNodeUnparser for MockSqlUnparser {
_query: &mut Option<&mut QueryBuilder>,
_select: &mut Option<&mut SelectBuilder>,
relation: &mut Option<&mut RelationBuilder>,
) -> Result<()> {
) -> Result<UnparseResult> {
if let Some(plan) = node.as_any().downcast_ref::<MockUserDefinedLogicalPlan>() {
let Statement::Query(input) = unparser.plan_to_sql(&plan.input)? else {
return Ok(());
return Ok(UnparseResult::Original);
};
let mut derived_builder = DerivedRelationBuilder::default();
derived_builder.subquery(input);
Expand All @@ -1522,7 +1549,7 @@ impl UserDefinedLogicalNodeUnparser for MockSqlUnparser {
rel.derived(derived_builder);
}
}
Ok(())
Ok(UnparseResult::WithinStatement)
}
}

Expand All @@ -1545,17 +1572,19 @@ fn test_unparse_extension_to_sql() -> Result<()> {
let plan = LogicalPlanBuilder::from(extension)
.project(vec![col("j1_id").alias("user_id")])?
.build()?;
let unparser =
Unparser::default().with_extension_unparsers(vec![Arc::new(MockSqlUnparser {})]);
let unparser = Unparser::default().with_extension_unparsers(vec![
Arc::new(MockSqlUnparser {}),
Arc::new(UnusedUnparser {}),
]);
let sql = unparser.plan_to_sql(&plan)?;
let expected = "SELECT j1.j1_id AS user_id FROM (SELECT * FROM j1)";
assert_eq!(sql.to_string(), expected);

if let Some(err) = plan_to_sql(&plan).err() {
assert_eq!(
assert_contains!(
err.to_string(),
"External error: `relation` must be initialized"
)
"This feature is not implemented: Unsupported extension node: MockUserDefinedLogicalPlan"
);
} else {
panic!("Expected error")
}
Expand Down

0 comments on commit 5233e4a

Please sign in to comment.