diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 4e963ae378109..2afca2b8e8549 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -266,8 +266,4 @@ class FileScanRDD(
     }
   }
 
-  def partitionFilesTotalLength: Long = {
-    filePartitions.map(_.files.map(_.length).sum).sum
-  }
-
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitExchangeExec.scala
new file mode 100644
index 0000000000000..10cc3885e620b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitExchangeExec.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.split
+
+import scala.annotation.tailrec
+import scala.collection.immutable.ListMap
+import scala.reflect.ClassTag
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.RoundRobinPartitioning
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.datasources.FileScanRDD
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
+
+case class SplitExchangeExec(maxExpandNum: Int, profitableSize: Long, child: SparkPlan)
+  extends UnaryExecNode {
+
+  /**
+   * @return All metrics containing metrics of this SparkPlan.
+   */
+  override lazy val metrics: Map[String, SQLMetric] = ListMap(
+    "numPartitions" -> SQLMetrics.createMetric(sparkContext, "number of partitions"),
+    "originalNumPartitions" -> SQLMetrics.createMetric(
+      sparkContext,
+      "original number of partitions"),
+    "dataSize" -> SQLMetrics
+      .createSizeMetric(sparkContext, "data size")) ++ readMetrics ++ writeMetrics
+
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+
+  /**
+   * Returns the name of this type of TreeNode.  Defaults to the class name.
+   * Note that we remove the "Exec" suffix for physical operators here.
+   */
+  override def nodeName: String = "SplitExchange"
+
+  override def output: Seq[Attribute] = child.output
+
+  /**
+   * The arguments that should be included in the arg string.  Defaults to the `productIterator`.
+   */
+  override protected def stringArgs: Iterator[Any] = super.stringArgs
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SplitExchangeExec =
+    copy(child = newChild)
+
+  /**
+   * Produces the result of the query as an `RDD[InternalRow]`
+   *
+   * Overridden by concrete implementations of SparkPlan.
+   */
+  override protected def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute()
+    val numPartitions = inputRDD.getNumPartitions
+    metrics("numPartitions").set(numPartitions)
+    metrics("originalNumPartitions").set(numPartitions)
+
+    val expandPartNum = maxExpandNum min session.leafNodeDefaultParallelism
+
+    val splitRDD = if (expandPartNum < (numPartitions << 1)) {
+      inputRDD
+    } else {
+      val sourceSize = evalSourceSize(inputRDD).getOrElse(-1L)
+      if (sourceSize < profitableSize) {
+        inputRDD
+      } else {
+        metrics("numPartitions").set(expandPartNum)
+        val serializer: Serializer =
+          new UnsafeRowSerializer(child.output.size, longMetric("dataSize"))
+        val shuffleDependency =
+          ShuffleExchangeExec.prepareShuffleDependency(
+            inputRDD,
+            child.output,
+            RoundRobinPartitioning(expandPartNum),
+            serializer,
+            writeMetrics)
+        new ShuffledRowRDD(shuffleDependency, readMetrics)
+      }
+    }
+
+    // update metrics
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
+
+    splitRDD
+  }
+
+  @tailrec
+  private def evalSourceSize[U: ClassTag](prev: RDD[U]): Option[Long] =
+    prev match {
+      case f: FileScanRDD => Some(f.filePartitions.map(_.files.map(_.length).sum).sum)
+      case r if r.dependencies.isEmpty => None
+      case o => evalSourceSize(o.firstParent)
+    }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitExec.scala
deleted file mode 100644
index c19ef609675f8..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitExec.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.split
-
-import scala.annotation.tailrec
-import scala.reflect.ClassTag
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode}
-import org.apache.spark.sql.execution.datasources.FileScanRDD
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-
-case class SplitExec(maxExpandNum: Int, thresholdSize: Long, child: SparkPlan)
-  extends UnaryExecNode {
-
-  /**
-   * @return All metrics containing metrics of this SparkPlan.
-   */
-  override lazy val metrics: Map[String, SQLMetric] = Map(
-    "originPartNum" -> SQLMetrics.createMetric(sparkContext, "origin partition num"),
-    "expandPartNum" -> SQLMetrics.createMetric(sparkContext, "expand partition num"))
-
-  /**
-   * Returns the name of this type of TreeNode.  Defaults to the class name.
-   * Note that we remove the "Exec" suffix for physical operators here.
-   */
-  override def nodeName: String = "SplitSourcePartition"
-
-  override def output: Seq[Attribute] = child.output
-
-  override protected def withNewChildInternal(newChild: SparkPlan): SplitExec =
-    copy(child = newChild)
-
-  /**
-   * Produces the result of the query as an `RDD[InternalRow]`
-   *
-   * Overridden by concrete implementations of SparkPlan.
-   */
-  override protected def doExecute(): RDD[InternalRow] = {
-    doSplit(child.execute())
-  }
-
-  private def doSplit[U: ClassTag](prev: RDD[U]): RDD[U] = {
-    val prevPartNum = prev.getNumPartitions
-    metrics("originPartNum").set(prevPartNum)
-    // default: do nothing
-    metrics("expandPartNum").set(prevPartNum)
-    val sourceSize = evalSourceSize(prev)
-    val after = sourceSize
-      .map { size =>
-        if (size < thresholdSize) {
-          // If source size is tiny, split will not be profitable.
-          prev
-        } else {
-          val expandPartNum = maxExpandNum min sparkContext.defaultParallelism
-          if (expandPartNum < (prevPartNum << 1)) {
-            // If expansion scale is tiny, split will also not be profitable.
-            prev
-          } else {
-            metrics("expandPartNum").set(expandPartNum)
-            // Maybe we could find better ways than `coalesce` to redistribute the partition data.
-            prev.coalesce(expandPartNum, shuffle = true)
-          }
-        }
-      }
-      .getOrElse(prev)
-
-    // update metrics
-    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
-    SQLMetrics.postDriverMetricUpdates(sparkContext, executionId, metrics.values.toSeq)
-    after
-  }
-
-  @tailrec
-  private def evalSourceSize[U: ClassTag](prev: RDD[U]): Option[Long] =
-    prev match {
-      case f: FileScanRDD => Some(f.partitionFilesTotalLength)
-      case r if r.dependencies.isEmpty => None
-      case other => evalSourceSize(other.firstParent)
-    }
-
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitSourcePartition.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitSourcePartition.scala
index 2ef989aa561be..18dc1e3eeeee6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitSourcePartition.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/split/SplitSourcePartition.scala
@@ -27,16 +27,20 @@ object SplitSourcePartition extends Rule[SparkPlan] {
 
   override def apply(plan: SparkPlan): SparkPlan = {
     if (!plan.conf.splitSourcePartitionEnabled ||
-      plan.find(_.isInstanceOf[DataSourceScanExec]).isEmpty) {
+      plan.find {
+        case _: DataSourceScanExec => true
+        case _ => false
+      }.isEmpty) {
       return plan
     }
 
     val r = plan.transformOnceWithPruning(shouldBreak) {
-      case p if p != null && p.isInstanceOf[DataSourceScanExec] =>
-        val maxExpandNum = plan.conf.splitSourcePartitionMaxExpandNum
-        val thresholdSize = plan.conf.splitSourcePartitionThreshold
-        SplitExec(maxExpandNum, thresholdSize, p)
-      case other => other
+      case d: DataSourceScanExec =>
+        SplitExchangeExec(
+          plan.conf.splitSourcePartitionMaxExpandNum,
+          plan.conf.splitSourcePartitionThreshold,
+          d)
+      case o => o
     }
     r
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SplitSourcePartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SplitSourcePartitionSuite.scala
index 41dce60a9a335..627e7dafb691a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SplitSourcePartitionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SplitSourcePartitionSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.{QueryTest, SaveMode}
-import org.apache.spark.sql.execution.split.SplitExec
+import org.apache.spark.sql.execution.split.SplitExchangeExec
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -57,8 +57,9 @@ class SplitSourcePartitionSuite extends QueryTest with SharedSparkSession {
         |""".stripMargin)
 
       val plan = df.queryExecution.executedPlan
-      assertResult(1, "SplitExec applied.")(plan.collectWithSubqueries { case e: SplitExec =>
-        e
+      assertResult(1, "SplitExchangeExec applied.")(plan.collectWithSubqueries {
+        case e: SplitExchangeExec =>
+          e
       }.size)
 
      assertResult(spark.sparkContext.defaultParallelism, "split partitions.")(