Commit 0702eb8
[GLUTEN-5016][CH] fix simple aggregation sql exchange fallback (#5042)
lwz9103 authored Mar 22, 2024
1 parent 201f322 commit 0702eb8
Showing 2 changed files with 163 additions and 115 deletions.
@@ -195,4 +195,19 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuit
|""".stripMargin) { _ => }
assert(result(0).getLong(0) == 227302L)
}

test("test 'GLUTEN-5016'") {
withSQLConf(("spark.gluten.sql.columnar.preferColumnar", "false")) {
val sql =
"""
|SELECT
| sum(l_quantity) AS sum_qty
|FROM
| lineitem
|WHERE
| l_shipdate <= date'1998-09-02'
|""".stripMargin
runSql(sql, noFallBack = true) { _ => }
}
}
}
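The new test pins down the GLUTEN-5016 scenario: with spark.gluten.sql.columnar.preferColumnar disabled, a simple aggregation over lineitem must still run fully offloaded (runSql with noFallBack = true). The following standalone sketch reproduces the same check outside the suite; it assumes a Gluten-enabled SparkSession with a registered TPC-H lineitem table, and the object name Gluten5016Check is purely illustrative.

import org.apache.spark.sql.SparkSession

object Gluten5016Check {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("gluten-5016-check").getOrCreate()
    // Mirror the suite's withSQLConf: disable the prefer-columnar heuristic.
    spark.conf.set("spark.gluten.sql.columnar.preferColumnar", "false")
    val df = spark.sql(
      """SELECT sum(l_quantity) AS sum_qty
        |FROM lineitem
        |WHERE l_shipdate <= date'1998-09-02'""".stripMargin)
    // Inspect the physical plan: with this fix the shuffle should be planned as a
    // columnar exchange instead of falling back to a row-based one.
    df.explain()
  }
}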
@@ -49,9 +49,12 @@ object MiscColumnarRules {

// Aggregation transformation.
private case class AggregationTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
override def apply(plan: SparkPlan): SparkPlan = plan match {
case plan if TransformHints.isNotTransformable(plan) =>
plan
case agg: HashAggregateExec =>
genHashAggregateExec(agg)
case other => other
}

/**
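In the hunk above, AggregationTransformRule no longer runs its own plan.transformUp traversal; it rewrites only the node it is handed (or returns it untouched when tagged not transformable), and the traversal is driven centrally by TransformPreOverrides (see the last hunk in this file). A minimal sketch of this node-local rule shape, pasteable into a Scala REPL; NodeLocalRule is an illustrative name, not something in the codebase.

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// A node-local rule: it never recurses on its own, so the caller decides whether
// it is applied top-down or bottom-up over the plan tree.
case class NodeLocalRule(rewrite: PartialFunction[SparkPlan, SparkPlan])
  extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan =
    rewrite.applyOrElse(plan, identity[SparkPlan])
}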
@@ -105,13 +108,144 @@ object MiscColumnarRules {
}
}

// Exchange transformation.
private case class ExchangeTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
override def apply(plan: SparkPlan): SparkPlan = plan match {
case plan if TransformHints.isNotTransformable(plan) =>
plan
case plan: ShuffleExchangeExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
val child = plan.child
if (
(child.supportsColumnar || GlutenConfig.getConf.enablePreferColumnar) &&
BackendsApiManager.getSettings.supportColumnarShuffleExec()
) {
BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
} else {
plan.withNewChildren(Seq(child))
}
case plan: BroadcastExchangeExec =>
val child = plan.child
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ColumnarBroadcastExchangeExec(plan.mode, child)
case other => other
}
}

// Join transformation.
private case class JoinTransformRule() extends Rule[SparkPlan] with LogLevelUtil {

/**
* Get the build side supported by the execution of vanilla Spark.
*
* @param plan
* : shuffled hash join plan
* @return
* the supported build side
*/
private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
plan.joinType match {
case LeftOuter | LeftSemi => BuildRight
case RightOuter => BuildLeft
case _ => plan.buildSide
}
}

override def apply(plan: SparkPlan): SparkPlan = {
if (TransformHints.isNotTransformable(plan)) {
logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
plan match {
case shj: ShuffledHashJoinExec =>
if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
// Because we manually removed the build side limitation for LeftOuter, LeftSemi and
// RightOuter, we need to change the build side back if this join falls back to
// vanilla Spark for execution.
return ShuffledHashJoinExec(
shj.leftKeys,
shj.rightKeys,
shj.joinType,
getSparkSupportedBuildSide(shj),
shj.condition,
shj.left,
shj.right,
shj.isSkewJoin
)
} else {
return shj
}
case p =>
return p
}
}
plan match {
case plan: ShuffledHashJoinExec =>
val left = plan.left
val right = plan.right
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
BackendsApiManager.getSparkPlanExecApiInstance
.genShuffledHashJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.buildSide,
plan.condition,
left,
right,
plan.isSkewJoin)
case plan: SortMergeJoinExec =>
val left = plan.left
val right = plan.right
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
SortMergeJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.condition,
left,
right,
plan.isSkewJoin)
case plan: BroadcastHashJoinExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastHashJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.buildSide,
plan.condition,
left,
right,
isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
case plan: CartesianProductExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genCartesianProductExecTransformer(left, right, plan.condition)
case plan: BroadcastNestedLoopJoinExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastNestedLoopJoinExecTransformer(
left,
right,
plan.buildSide,
plan.joinType,
plan.condition)
case other => other
}
}

}

// Filter transformation.
private case class FilterTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
private val replace = new ReplaceSingleNode()

override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
override def apply(plan: SparkPlan): SparkPlan = plan match {
case filter: FilterExec =>
genFilterExec(filter)
case other => other
}

/**
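ExchangeTransformRule, extracted above, keeps the existing offload condition for shuffles: the exchange is converted to a columnar shuffle only when the child already produces columnar batches or preferColumnar is enabled, and the backend supports columnar shuffle. A hedged restatement of that predicate for a Scala REPL, with illustrative names (shouldOffloadShuffle is not part of the project's API):

// Illustrative restatement of ExchangeTransformRule's decision; the real rule reads
// these flags from the child plan, GlutenConfig and the backend settings.
def shouldOffloadShuffle(
    childSupportsColumnar: Boolean,
    preferColumnar: Boolean,
    backendSupportsColumnarShuffle: Boolean): Boolean =
  (childSupportsColumnar || preferColumnar) && backendSupportsColumnarShuffle

// With preferColumnar disabled (as in the GLUTEN-5016 test), the shuffle is still
// offloaded as long as its child has already been rewritten to a columnar operator.
assert(shouldOffloadShuffle(childSupportsColumnar = true, preferColumnar = false,
  backendSupportsColumnarShuffle = true))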
@@ -155,39 +289,18 @@ object MiscColumnarRules {
private case class RegularTransformRule() extends Rule[SparkPlan] with LogLevelUtil {
private val replace = new ReplaceSingleNode()

override def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case plan => replace.replaceWithTransformerPlan(plan)
}
override def apply(plan: SparkPlan): SparkPlan = replace.replaceWithTransformerPlan(plan)
}

// Utility to replace single node within transformed Gluten node.
// Children will be preserved as they are as children of the output node.
class ReplaceSingleNode() extends LogLevelUtil with Logging {
private val columnarConf: GlutenConfig = GlutenConfig.getConf

def replaceWithTransformerPlan(p: SparkPlan): SparkPlan = {
val plan = p
if (TransformHints.isNotTransformable(plan)) {
logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
plan match {
case shj: ShuffledHashJoinExec =>
if (BackendsApiManager.getSettings.recreateJoinExecOnFallback()) {
// Because we manually removed the build side limitation for LeftOuter, LeftSemi and
// RightOuter, we need to change the build side back if this join falls back to
// vanilla Spark for execution.
return ShuffledHashJoinExec(
shj.leftKeys,
shj.rightKeys,
shj.joinType,
getSparkSupportedBuildSide(shj),
shj.condition,
shj.left,
shj.right,
shj.isSkewJoin
)
} else {
return shj
}
case plan: BatchScanExec =>
return applyScanNotTransformable(plan)
case plan: FileSourceScanExec =>
@@ -283,75 +396,6 @@ object MiscColumnarRules {
plan.projectList,
child,
offset)
case plan: ShuffleExchangeExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
val child = plan.child
if (
(child.supportsColumnar || columnarConf.enablePreferColumnar) &&
BackendsApiManager.getSettings.supportColumnarShuffleExec()
) {
BackendsApiManager.getSparkPlanExecApiInstance.genColumnarShuffleExchange(plan, child)
} else {
plan.withNewChildren(Seq(child))
}
case plan: ShuffledHashJoinExec =>
val left = plan.left
val right = plan.right
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
BackendsApiManager.getSparkPlanExecApiInstance
.genShuffledHashJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.buildSide,
plan.condition,
left,
right,
plan.isSkewJoin)
case plan: SortMergeJoinExec =>
val left = plan.left
val right = plan.right
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
SortMergeJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.condition,
left,
right,
plan.isSkewJoin)
case plan: BroadcastExchangeExec =>
val child = plan.child
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
ColumnarBroadcastExchangeExec(plan.mode, child)
case plan: BroadcastHashJoinExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastHashJoinExecTransformer(
plan.leftKeys,
plan.rightKeys,
plan.joinType,
plan.buildSide,
plan.condition,
left,
right,
isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
case plan: CartesianProductExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genCartesianProductExecTransformer(left, right, plan.condition)
case plan: BroadcastNestedLoopJoinExec =>
val left = plan.left
val right = plan.right
BackendsApiManager.getSparkPlanExecApiInstance
.genBroadcastNestedLoopJoinExecTransformer(
left,
right,
plan.buildSide,
plan.joinType,
plan.condition)
case plan: WindowExec =>
WindowExecTransformer(
plan.windowExpression,
@@ -389,22 +433,6 @@ object MiscColumnarRules {
}
}

/**
* Get the build side supported by the execution of vanilla Spark.
*
* @param plan
* : shuffled hash join plan
* @return
* the supported build side
*/
private def getSparkSupportedBuildSide(plan: ShuffledHashJoinExec): BuildSide = {
plan.joinType match {
case LeftOuter | LeftSemi => BuildRight
case RightOuter => BuildLeft
case _ => plan.buildSide
}
}

private def applyScanNotTransformable(plan: SparkPlan): SparkPlan = plan match {
case plan: FileSourceScanExec =>
val newPartitionFilters =
@@ -489,18 +517,23 @@ object MiscColumnarRules {
case class TransformPreOverrides() extends Rule[SparkPlan] with LogLevelUtil {
import TransformPreOverrides._

private val subRules = List(
FilterTransformRule(),
private val topdownRules = List(
FilterTransformRule()
)
private val bottomupRules = List(
RegularTransformRule(),
AggregationTransformRule()
AggregationTransformRule(),
ExchangeTransformRule(),
JoinTransformRule()
)

@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()

def apply(plan: SparkPlan): SparkPlan = {
val newPlan = subRules.foldLeft(plan)((p, rule) => rule.apply(p))
planChangeLogger.logRule(ruleName, plan, newPlan)
newPlan
val plan0 = topdownRules.foldLeft(plan)((p, rule) => p.transformDown { case p => rule(p) })
val plan1 = bottomupRules.foldLeft(plan0)((p, rule) => p.transformUp { case p => rule(p) })
planChangeLogger.logRule(ruleName, plan, plan1)
plan1
}
}
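This split of TransformPreOverrides into topdownRules and bottomupRules is the heart of the change: every extracted rule is node-local, the filter rule is applied top-down, and the regular, aggregation, exchange and join rules are applied bottom-up in that order, so the exchange rule evaluates child.supportsColumnar against a child that the earlier rules have already rewritten (for example an offloaded aggregate). The toy below only illustrates the two traversal orders and can be pasted into a Scala REPL; Node, Leaf and Branch are made-up types with no Spark dependency.

// Toy tree, purely to show traversal order; not Gluten or Spark code.
sealed trait Node { def kids: Seq[Node]; def withKids(k: Seq[Node]): Node }
case class Leaf(tag: String) extends Node {
  def kids: Seq[Node] = Nil
  def withKids(k: Seq[Node]): Node = this
}
case class Branch(tag: String, kids: Seq[Node]) extends Node {
  def withKids(k: Seq[Node]): Node = copy(kids = k)
}

// Bottom-up: children are rewritten before the rule sees the parent.
def transformUp(n: Node)(rule: Node => Node): Node =
  rule(n.withKids(n.kids.map(c => transformUp(c)(rule))))

// Top-down: the parent is rewritten first, then its (possibly new) children.
def transformDown(n: Node)(rule: Node => Node): Node = {
  val r = rule(n)
  r.withKids(r.kids.map(c => transformDown(c)(rule)))
}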
