Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unparsing optimized (> 2 inputs) unions #14031

Merged
merged 10 commits into from
Jan 9, 2025
23 changes: 10 additions & 13 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,6 @@ impl Unparser<'_> {
Ok(())
}
LogicalPlan::Union(union) => {
if union.inputs.len() != 2 {
return not_impl_err!(
"UNION ALL expected 2 inputs, but found {}",
union.inputs.len()
);
}

// Covers cases where the UNION is a subquery and the projection is at the top level
if select.already_projected() {
return self.derive_with_dialect_alias(
Expand All @@ -729,12 +722,16 @@ impl Unparser<'_> {
.map(|input| self.select_to_sql_expr(input, query))
.collect::<Result<Vec<_>>>()?;

let union_expr = SetExpr::SetOperation {
op: ast::SetOperator::Union,
set_quantifier: ast::SetQuantifier::All,
left: Box::new(input_exprs[0].clone()),
right: Box::new(input_exprs[1].clone()),
};
let union_expr = input_exprs
MohamedAbdeen21 marked this conversation as resolved.
Show resolved Hide resolved
.into_iter()
.rev()
MohamedAbdeen21 marked this conversation as resolved.
Show resolved Hide resolved
.reduce(|a, b| SetExpr::SetOperation {
op: ast::SetOperator::Union,
set_quantifier: ast::SetQuantifier::All,
left: Box::new(b),
right: Box::new(a),
})
.unwrap();

let Some(query) = query.as_mut() else {
return internal_err!(
Expand Down
41 changes: 37 additions & 4 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::*;
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::{assert_contains, DFSchema, DFSchemaRef, Result, TableReference};
use datafusion_expr::test::function_stub::{
count_udaf, max_udaf, min_udaf, sum, sum_udaf,
};
use datafusion_expr::{
col, lit, table_scan, wildcard, Expr, Extension, LogicalPlan, LogicalPlanBuilder,
UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
col, lit, table_scan, wildcard, EmptyRelation, Expr, Extension, LogicalPlan,
LogicalPlanBuilder, Union, UserDefinedLogicalNode, UserDefinedLogicalNodeCore,
};
use datafusion_functions::unicode;
use datafusion_functions_aggregate::grouping::grouping_udaf;
Expand All @@ -42,7 +42,7 @@ use std::{fmt, vec};

use crate::common::{MockContextProvider, MockSessionState};
use datafusion_expr::builder::{
table_scan_with_filter_and_fetch, table_scan_with_filters,
project, table_scan_with_filter_and_fetch, table_scan_with_filters,
};
use datafusion_functions::core::planner::CoreFunctionPlanner;
use datafusion_functions_nested::extract::array_element_udf;
Expand Down Expand Up @@ -1615,3 +1615,36 @@ fn test_unparse_extension_to_sql() -> Result<()> {
}
Ok(())
}

/// A `Union` plan node holding four inputs must unparse into a single chain
/// of `UNION ALL` operators, with the inputs emitted in their original order.
#[test]
fn test_unparse_optimized_multi_union() -> Result<()> {
    // Two non-nullable columns shared by the empty relation and the union.
    let fields = vec![
        Field::new("x", DataType::Int32, false),
        Field::new("y", DataType::Utf8, false),
    ];
    let dfschema = Arc::new(DFSchema::try_from(Schema::new(fields))?);

    // Single-row relation that every projection below selects its literals from.
    let empty = LogicalPlan::EmptyRelation(EmptyRelation {
        produce_one_row: true,
        schema: dfschema.clone(),
    });

    // One projection per (x, y) literal pair; the order here is the order
    // expected in the generated SQL.
    let inputs: Vec<Arc<LogicalPlan>> = [(1, "a"), (1, "b"), (2, "a"), (2, "c")]
        .into_iter()
        .map(|(x, y)| {
            let projected =
                project(empty.clone(), vec![lit(x).alias("x"), lit(y).alias("y")])?;
            Ok(projected.into())
        })
        .collect::<Result<_>>()?;

    // Build the Union node directly so that it carries more than two inputs.
    let plan = LogicalPlan::Union(Union {
        inputs,
        schema: dfschema,
    });

    let expected = "SELECT 1 AS x, 'a' AS y UNION ALL SELECT 1 AS x, 'b' AS y UNION ALL SELECT 2 AS x, 'a' AS y UNION ALL SELECT 2 AS x, 'c' AS y";

    let unparser = Unparser::default();
    let actual = unparser.plan_to_sql(&plan)?.to_string();
    assert_eq!(actual, expected);

    Ok(())
}
Loading