Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unparsing optimized (> 2 inputs) unions #14031

Merged
merged 10 commits into from
Jan 9, 2025
26 changes: 14 additions & 12 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,6 @@ impl Unparser<'_> {
Ok(())
}
LogicalPlan::Union(union) => {
if union.inputs.len() != 2 {
return not_impl_err!(
"UNION ALL expected 2 inputs, but found {}",
union.inputs.len()
);
}

// Covers cases where the UNION is a subquery and the projection is at the top level
if select.already_projected() {
return self.derive_with_dialect_alias(
Expand All @@ -729,11 +722,20 @@ impl Unparser<'_> {
.map(|input| self.select_to_sql_expr(input, query))
.collect::<Result<Vec<_>>>()?;

let union_expr = SetExpr::SetOperation {
op: ast::SetOperator::Union,
set_quantifier: ast::SetQuantifier::All,
left: Box::new(input_exprs[0].clone()),
right: Box::new(input_exprs[1].clone()),
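// Build the union expression tree bottom-up by folding over the inputs in reverse order.
// Because of the `rev`, the accumulator `a` holds the later inputs, so it becomes the
// right operand while the next (earlier) input `b` becomes the left operand.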
let Some(union_expr) =
input_exprs
.into_iter()
.rev()
.reduce(|a, b| SetExpr::SetOperation {
op: ast::SetOperator::Union,
set_quantifier: ast::SetQuantifier::All,
left: Box::new(b),
right: Box::new(a),
})
else {
return internal_err!("UNION operator requires at least 2 inputs");
};

let Some(query) = query.as_mut() else {
Expand Down
41 changes: 37 additions & 4 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::*;
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, TableReference};
use datafusion_expr::test::function_stub::{
count_udaf, max_udaf, min_udaf, sum, sum_udaf,
};
use datafusion_expr::{
col, lit, table_scan, wildcard, Expr, Extension, LogicalPlan, LogicalPlanBuilder,
UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan,
LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
};
use datafusion_functions::unicode;
use datafusion_functions_aggregate::grouping::grouping_udaf;
Expand All @@ -42,7 +42,7 @@ use std::{fmt, vec};

use crate::common::{MockContextProvider, MockSessionState};
use datafusion_expr::builder::{
table_scan_with_filter_and_fetch, table_scan_with_filters,
project, table_scan_with_filter_and_fetch, table_scan_with_filters,
};
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_functions_nested::extract::array_element_udf;
Expand Down Expand Up @@ -1615,3 +1615,36 @@ fn test_unparse_extension_to_sql() -> Result<()> {
}
Ok(())
}

/// A `Union` node carrying more than two inputs (the shape an optimizer can
/// produce by flattening nested unions) must unparse into a single chain of
/// `UNION ALL` operators that keeps the inputs in their original order.
#[test]
fn test_unparse_optimized_multi_union() -> Result<()> {
    // Schema shared by the empty relation and by the union itself.
    let dfschema = Arc::new(DFSchema::try_from(Schema::new(vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Utf8, false),
    ]))?);

    // Single-row source that every projection below is built on.
    let empty = LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: true,
        schema: dfschema.clone(),
    });

    // One `(x, y)` literal pair per union input, in the expected output order.
    let rows: [(i32, &str); 4] = [(1, "a"), (1, "b"), (2, "a"), (2, "c")];
    let inputs = rows
        .iter()
        .map(|&(x, y)| {
            project(empty.clone(), vec![lit(x).alias("x"), lit(y).alias("y")])
                .map(Arc::new)
        })
        .collect::<Result<Vec<_>>>()?;

    let plan = LogicalPlan::Union(Union {
        inputs,
        schema: dfschema.clone(),
    });

    let expected = "SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y";

    let unparser = Unparser::default();
    let actual = unparser.plan_to_sql(&plan)?.to_string();
    assert_eq!(actual, expected);

    Ok(())
}
Loading