From 7e63134ff16f42dfb081518040f5d962813be1c0 Mon Sep 17 00:00:00 2001 From: lwz9103 Date: Wed, 20 Mar 2024 18:12:42 +0800 Subject: [PATCH] [GLUTEN-5016][CH] fix simple aggregation sql exchange fallback --- .../GlutenClickHouseTPCHNullableSuite.scala | 15 +++++++ .../columnar/MiscColumnarRules.scala | 40 +++++++++++-------- 2 files changed, 38 insertions(+), 17 deletions(-) diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala index 72a04f02c1ecb..42783babd124c 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala @@ -195,4 +195,19 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuit |""".stripMargin) { _ => } assert(result(0).getLong(0) == 227302L) } + + test("test 'GLUTEN-5016'") { + withSQLConf(("spark.gluten.sql.columnar.preferColumnar", "false")) { + val sql = + """ + |SELECT + | sum(l_quantity) AS sum_qty + |FROM + | lineitem + |WHERE + | l_shipdate <= date'1998-09-02' + |""".stripMargin + runSql(sql, noFallBack = true) { _ => } + } + } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala index 4fbadb0b50ac7..b5ef42ac43e97 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/MiscColumnarRules.scala @@ -105,6 +105,27 @@ object MiscColumnarRules { } } + // Exchange transformation. + private case class ExchangeTransformRule() extends Rule[SparkPlan] with LogLevelUtil { + override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case plan: ShuffleExchangeExec => + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + val child = plan.child + if ( + (child.supportsColumnar || GlutenConfig.getConf.enablePreferColumnar) && + BackendsApiManager.getSettings.supportColumnarShuffleExec() + ) { + BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child) + } else { + plan.withNewChildren(Seq(child)) + } + case plan: BroadcastExchangeExec => + val child = plan.child + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + ColumnarBroadcastExchangeExec(plan.mode, child) + } + } + // Filter transformation. private case class FilterTransformRule() extends Rule[SparkPlan] with LogLevelUtil { private val replace = new ReplaceSingleNode() @@ -163,7 +184,6 @@ object MiscColumnarRules { // Utility to replace single node within transformed Gluten node. // Children will be preserved as they are as children of the output node. class ReplaceSingleNode() extends LogLevelUtil with Logging { - private val columnarConf: GlutenConfig = GlutenConfig.getConf def replaceWithTransformerPlan(p: SparkPlan): SparkPlan = { val plan = p @@ -283,17 +303,6 @@ object MiscColumnarRules { plan.projectList, child, offset) - case plan: ShuffleExchangeExec => - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - val child = plan.child - if ( - (child.supportsColumnar || columnarConf.enablePreferColumnar) && - BackendsApiManager.getSettings.supportColumnarShuffleExec() - ) { - BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child) - } else { - plan.withNewChildren(Seq(child)) - } case plan: ShuffledHashJoinExec => val left = plan.left val right = plan.right @@ -320,10 +329,6 @@ object MiscColumnarRules { left, right, plan.isSkewJoin) - case plan: BroadcastExchangeExec => - val child = plan.child - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ColumnarBroadcastExchangeExec(plan.mode, child) case plan: BroadcastHashJoinExec => val left = plan.left val right = plan.right @@ -492,7 +497,8 @@ object MiscColumnarRules { private val subRules = List( FilterTransformRule(), RegularTransformRule(), - AggregationTransformRule() + AggregationTransformRule(), + ExchangeTransformRule() ) @transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()