KE-42300 support split source partition
Yu Gan committed Oct 17, 2023
1 parent 82c4992 commit e913aa8
Showing 7 changed files with 294 additions and 0 deletions.
@@ -549,6 +549,32 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
}
}

/**
* Returns a copy of this node where `rule` has been applied once, top-down, stopping the
* descent at the first node the rule rewrites. Subtrees for which `breaker` returns true are
* skipped, as are subtrees already marked as ineffective for `ruleId`.
*/
def transformOnceWithPruning(breaker: BaseType => Boolean, ruleId: RuleId = UnknownRuleId)(
rule: PartialFunction[BaseType, BaseType]): BaseType = {
if (breaker(this) || isRuleIneffective(ruleId)) {
return this
}

val afterRule = CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(this, identity[BaseType])
}
// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
val rewrittenPlan = mapChildren(_.transformOnceWithPruning(breaker, ruleId)(rule))
if (this eq rewrittenPlan) {
markRuleAsIneffective(ruleId)
this
} else {
rewrittenPlan
}
} else {
// If the transform function replaces this node with a new one, carry over the tags.
afterRule.copyTagsFrom(this)
afterRule
}
}

/**
* Returns a copy of this node where `rule` has been recursively applied first to all of its
* children and then itself (post-order). When `rule` does not apply to a given node, it is left
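For context, a rule built on the new helper might look like the sketch below. `WrapLeavesOnce` and its use of `ProjectExec` are illustrative only and not part of this commit: the rule visits the plan once, skips any broadcast-exchange subtree via the breaker, and stops descending below a node it has rewritten.

  import org.apache.spark.sql.catalyst.rules.Rule
  import org.apache.spark.sql.execution.{LeafExecNode, ProjectExec, SparkPlan}
  import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec

  // Hypothetical rule, for illustration only: add a no-op ProjectExec above every leaf node,
  // visiting the tree at most once and pruning broadcast-exchange subtrees.
  object WrapLeavesOnce extends Rule[SparkPlan] {
    override def apply(plan: SparkPlan): SparkPlan =
      plan.transformOnceWithPruning(_.isInstanceOf[BroadcastExchangeExec]) {
        case leaf: LeafExecNode => ProjectExec(leaf.output, leaf)
      }
  }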
@@ -547,6 +547,28 @@ object SQLConf {
.checkValue(_ > 0, "advisoryPartitionSizeInBytes must be positive")
.createWithDefaultString("64MB")

val SPLIT_SOURCE_PARTITION_ENABLED =
buildConf("spark.sql.splitSourcePartition.enabled")
.doc("When true, split the partitions of file-based source scans into more partitions " +
"to increase parallelism when the scanned data is large enough.")
.version("3.2.0")
.booleanConf
.createWithDefault(false)

val SPLIT_SOURCE_PARTITION_THRESHOLD =
buildConf("spark.sql.splitSourcePartition.thresholdInBytes")
.doc("The minimum total size of the scanned source data for which its partitions are " +
"considered for splitting.")
.version("3.2.0")
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("1MB")

val SPLIT_SOURCE_PARTITION_MAXEXPANDNUM =
buildConf("spark.sql.splitSourcePartition.maxExpandNum")
.doc("The maximum number of partitions to split into.")
.version("3.2.0")
.intConf
.checkValue(_ > 0, "The partition number must be a positive integer.")
.createWithDefault(10)

val ADAPTIVE_EXECUTION_ENABLED = buildConf("spark.sql.adaptive.enabled")
.doc("When true, enable adaptive query execution, which re-optimizes the query plan in the " +
"middle of query execution, based on accurate runtime statistics.")
@@ -3797,6 +3819,12 @@ class SQLConf extends Serializable with Logging {

def maxCollectSize: Option[Long] = getConf(SQLConf.MAX_COLLECT_SIZE)

def splitSourcePartitionEnabled: Boolean = getConf(SPLIT_SOURCE_PARTITION_ENABLED)

def splitSourcePartitionThreshold: Long = getConf(SPLIT_SOURCE_PARTITION_THRESHOLD)

def splitSourcePartitionMaxExpandNum: Int = getConf(SPLIT_SOURCE_PARTITION_MAXEXPANDNUM)

def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED)

def adaptiveExecutionLogLevel: String = getConf(ADAPTIVE_EXECUTION_LOG_LEVEL)
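For reference, the new knobs could be toggled per session roughly as follows (a usage sketch, not part of this commit; the threshold and expansion values are arbitrary examples):

  // Usage sketch: enable source-partition splitting and tune its limits for one session.
  spark.conf.set("spark.sql.splitSourcePartition.enabled", "true")
  spark.conf.set("spark.sql.splitSourcePartition.thresholdInBytes", "64MB")
  spark.conf.set("spark.sql.splitSourcePartition.maxExpandNum", "16")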
@@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableU
import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
import org.apache.spark.sql.execution.split.SplitSourcePartition
import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.OutputMode
@@ -427,6 +428,9 @@ object QueryExecution {
// number of partitions when instantiating PartitioningCollection.
RemoveRedundantSorts,
DisableUnnecessaryBucketedScan,
// `SplitSourcePartition` must come before `ApplyColumnarRulesAndInsertTransitions` so that
// `SplitExec` consumes row output rather than `ColumnarBatch` output from DataSourceScanExec.
SplitSourcePartition,
ApplyColumnarRulesAndInsertTransitions(
sparkSession.sessionState.columnarRules, outputsColumnar = false),
CollapseCodegenStages()) ++
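Because the columnar rule runs afterwards, any `ColumnarToRowExec` it inserts should land beneath `SplitExec`, so the scan-side fragment presumably ends up as `SplitExec` over `ColumnarToRowExec` over the file scan, with `SplitExec` seeing rows rather than `ColumnarBatch`es (an expected shape, not verified plan output).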
@@ -233,4 +233,9 @@ class FileScanRDD(
expectedTargets
}
}

/** Total length in bytes of all the files in this scan, summed across its partitions. */
def partitionFilesTotalLength: Long = {
filePartitions.map(_.files.map(_.length).sum).sum
}

}
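The new helper simply sums the on-disk lengths of every file across all `FilePartition`s of the scan; for example, two partitions holding files of 3 MB + 5 MB and 4 MB would report 12 MB. The `SplitExec` operator added in this commit uses this value as its estimate of the source size.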
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.split

import scala.annotation.tailrec
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode}
import org.apache.spark.sql.execution.datasources.FileScanRDD
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

case class SplitExec(maxExpandNum: Int, thresholdSize: Long, child: SparkPlan)
extends UnaryExecNode {

/**
* Metrics reported by this operator: the partition counts before and after splitting.
*/
override lazy val metrics: Map[String, SQLMetric] = Map(
"originPartNum" -> SQLMetrics.createMetric(sparkContext, "original partition num"),
"expandPartNum" -> SQLMetrics.createMetric(sparkContext, "expanded partition num"))

/**
* Reported node name. Overridden so this operator shows up as "SplitSourcePartition"
* in plan output instead of the default "Split" derived from the class name.
*/
override def nodeName: String = "SplitSourcePartition"

override def output: Seq[Attribute] = child.output

override protected def withNewChildInternal(newChild: SparkPlan): SplitExec =
copy(child = newChild)

/**
* Executes the child plan and, when profitable, redistributes its output into more partitions.
*/
override protected def doExecute(): RDD[InternalRow] = {
doSplit(child.execute())
}

private def doSplit[U: ClassTag](prev: RDD[U]): RDD[U] = {
val prevPartNum = prev.getNumPartitions
metrics("originPartNum").set(prevPartNum)
// default: do nothing
metrics("expandPartNum").set(prevPartNum)
val sourceSize = evalSourceSize(prev)
val after = sourceSize
.map { size =>
if (size < thresholdSize) {
// If source size is tiny, split will not be profitable.
prev
} else {
val expandPartNum = maxExpandNum min sparkContext.defaultParallelism
if (expandPartNum < (prevPartNum << 1)) {
// If expansion scale is tiny, split will also not be profitable.
prev
} else {
metrics("expandPartNum").set(expandPartNum)
// Maybe we could find better ways than `coalesce` to redistribute the partition data.
prev.coalesce(expandPartNum, shuffle = true)
}
}
}
.getOrElse(prev)

// update metrics
val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
after
}

@tailrec
private def evalSourceSize[U: ClassTag](prev: RDD[U]): Option[Long] =
prev match {
case f: FileScanRDD => Some(f.partitionFilesTotalLength)
case r if r.dependencies.isEmpty => None
case other => evalSourceSize(other.firstParent)
}

}
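To make the decision in `doSplit` concrete, here is a stand-alone sketch with assumed example values (the numbers are illustrative, not from this commit):

  // Stand-alone sketch of the split decision in SplitExec.doSplit, with assumed example values.
  val thresholdSize = 1L << 20        // 1 MB threshold
  val maxExpandNum = 10
  val defaultParallelism = 8          // assumed cluster default parallelism
  val prevPartNum = 2                 // partitions before splitting
  val sourceSize = 256L << 20         // 256 MB of scanned source data

  val targetPartitions =
    if (sourceSize < thresholdSize) None                       // too small: splitting not profitable
    else {
      val expandPartNum = maxExpandNum min defaultParallelism  // = 8
      if (expandPartNum < (prevPartNum << 1)) None             // expansion below 2x: not profitable
      else Some(expandPartNum)                                 // coalesce(expandPartNum, shuffle = true)
    }
  // targetPartitions == Some(8); with defaultParallelism = 3 it would be None, since 3 < 4.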
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.split

import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec

object SplitSourcePartition extends Rule[SparkPlan] {

override def apply(plan: SparkPlan): SparkPlan = {
if (!plan.conf.splitSourcePartitionEnabled ||
plan.find(_.isInstanceOf[DataSourceScanExec]).isEmpty) {
return plan
}

plan.transformOnceWithPruning(shouldBreak) {
case scan: DataSourceScanExec =>
val maxExpandNum = plan.conf.splitSourcePartitionMaxExpandNum
val thresholdSize = plan.conf.splitSourcePartitionThreshold
SplitExec(maxExpandNum, thresholdSize, scan)
}
}

// Stop descending once a broadcast exchange sits directly above a chain that ends in a source scan.
private def shouldBreak(plan: SparkPlan): Boolean =
plan match {
case BroadcastExchangeExec(_, c) => askChild(c)
/* case p if !p.requiredChildDistribution.forall(_ == UnspecifiedDistribution) =>
p.children.exists(supportSplit) */
case _ => false
}

// Walks a chain of unary nodes downwards to check whether it bottoms out in a DataSourceScanExec.
@tailrec
private def askChild(plan: SparkPlan): Boolean =
plan match {
case null => false
case l: LeafExecNode => l.isInstanceOf[DataSourceScanExec]
case u: UnaryExecNode => askChild(u.child)
case _ => false
}

}
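The breaker prevents the rule from descending into the build side of a broadcast join: when a `BroadcastExchangeExec` sits directly above a unary chain ending in a `DataSourceScanExec`, that scan is left unwrapped, presumably because data that is about to be broadcast is small and an extra shuffle there would only add overhead. This matches the new test below, where the broadcast join over two scans is expected to produce exactly one `SplitExec`.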
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution

import org.apache.spark.sql.{QueryTest, SaveMode}
import org.apache.spark.sql.execution.split.SplitExec
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

class SplitSourcePartitionSuite extends QueryTest with SharedSparkSession {

private val TABLE_FORMAT: String = "parquet"

test("split source partition") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
SQLConf.FILES_MIN_PARTITION_NUM.key -> "1",
SQLConf.SPLIT_SOURCE_PARTITION_ENABLED.key -> "true",
SQLConf.SPLIT_SOURCE_PARTITION_THRESHOLD.key -> "1B") {
withTable("ssp_t1", "ssp_t2") {
spark
.range(10)
.select(col("id"), col("id").as("k"))
.write
.mode(SaveMode.Overwrite)
.format(TABLE_FORMAT)
.saveAsTable("ssp_t1")

spark
.range(5)
.select(col("id"), col("id").as("k"))
.write
.mode(SaveMode.Overwrite)
.format(TABLE_FORMAT)
.saveAsTable("ssp_t2")

val df = sql("""
|SELECT ssp_t1.id, ssp_t2.k
|FROM ssp_t1 INNER JOIN ssp_t2 ON ssp_t1.k = ssp_t2.k
|WHERE ssp_t2.id < 2
|""".stripMargin)

val plan = df.queryExecution.executedPlan
assertResult(1, "SplitExec applied.")(plan.collectWithSubqueries { case e: SplitExec =>
e
}.size)

assertResult(spark.sparkContext.defaultParallelism, "split partitions.")(
df.rdd.partitions.length)
}

}
}

}
