diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 9dcb38f8ff10e..c5c2f9bb6a6f6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -30,6 +30,7 @@ import org.apache.spark.internal.config.{SPARK_DRIVER_PREFIX, SPARK_EXECUTOR_PRE
 import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike
 import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, SparkListenerSQLExecutionStart}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.StaticSQLConf.SQL_EVENT_TRUNCATE_LENGTH
@@ -178,8 +179,11 @@ object SQLExecution extends Logging {
     val shuffleIds = queryExecution.executedPlan match {
       case ae: AdaptiveSparkPlanExec =>
         ae.context.shuffleIds.asScala.keys
-      case _ =>
-        Iterable.empty
+      case nonAdaptivePlan =>
+        nonAdaptivePlan.collect {
+          case exec: ShuffleExchangeLike =>
+            exec.shuffleId
+        }
     }
     shuffleIds.foreach { shuffleId =>
       queryExecution.shuffleCleanupMode match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
index 1aab2a855bb4a..0d14faaf8144a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala
@@ -328,36 +328,51 @@ class QueryExecutionSuite extends SharedSparkSession {
   }
 
   test("SPARK-47764: Cleanup shuffle dependencies - DoNotCleanup mode") {
-    val plan = spark.range(100).repartition(10).logicalPlan
-    val df = Dataset.ofRows(spark, plan, DoNotCleanup)
-    df.collect()
-
-    val blockManager = spark.sparkContext.env.blockManager
-    assert(blockManager.migratableResolver.getStoredShuffles().nonEmpty)
-    assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
-    cleanupShuffles()
+    Seq(true, false).foreach { adaptiveEnabled => {
+      withSQLConf((SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, adaptiveEnabled.toString)) {
+        val plan = spark.range(100).repartition(10).logicalPlan
+        val df = Dataset.ofRows(spark, plan, DoNotCleanup)
+        df.collect()
+
+        val blockManager = spark.sparkContext.env.blockManager
+        assert(blockManager.migratableResolver.getStoredShuffles().nonEmpty)
+        assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
+        cleanupShuffles()
+      }
+    }
+    }
   }
 
   test("SPARK-47764: Cleanup shuffle dependencies - SkipMigration mode") {
-    val plan = spark.range(100).repartition(10).logicalPlan
-    val df = Dataset.ofRows(spark, plan, SkipMigration)
-    df.collect()
-
-    val blockManager = spark.sparkContext.env.blockManager
-    assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
-    assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
-    cleanupShuffles()
+    Seq(true, false).foreach { adaptiveEnabled => {
+      withSQLConf((SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, adaptiveEnabled.toString)) {
+        val plan = spark.range(100).repartition(10).logicalPlan
+        val df = Dataset.ofRows(spark, plan, SkipMigration)
+        df.collect()
+
+        val blockManager = spark.sparkContext.env.blockManager
+        assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
+        assert(blockManager.diskBlockManager.getAllBlocks().nonEmpty)
+        cleanupShuffles()
+      }
+    }
+    }
   }
 
   test("SPARK-47764: Cleanup shuffle dependencies - RemoveShuffleFiles mode") {
-    val plan = spark.range(100).repartition(10).logicalPlan
-    val df = Dataset.ofRows(spark, plan, RemoveShuffleFiles)
-    df.collect()
-
-    val blockManager = spark.sparkContext.env.blockManager
-    assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
-    assert(blockManager.diskBlockManager.getAllBlocks().isEmpty)
-    cleanupShuffles()
+    Seq(true, false).foreach { adaptiveEnabled => {
+      withSQLConf((SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, adaptiveEnabled.toString)) {
+        val plan = spark.range(100).repartition(10).logicalPlan
+        val df = Dataset.ofRows(spark, plan, RemoveShuffleFiles)
+        df.collect()
+
+        val blockManager = spark.sparkContext.env.blockManager
+        assert(blockManager.migratableResolver.getStoredShuffles().isEmpty)
+        assert(blockManager.diskBlockManager.getAllBlocks().isEmpty)
+        cleanupShuffles()
+      }
+    }
+    }
   }
 
   test("SPARK-35378: Return UnsafeRow in CommandResultExecCheck execute methods") {
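
Note: the new non-AQE branch in SQLExecution.scala is a plain physical-plan
traversal; AQE keeps registering its shuffle IDs in the adaptive context as
before. A minimal standalone sketch of the same traversal is below
(collectShuffleIds is a hypothetical helper name, not part of the patch):

    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike

    // Walk the executed plan and gather the shuffle ID of every shuffle
    // exchange node, mirroring the nonAdaptivePlan branch added above.
    def collectShuffleIds(plan: SparkPlan): Seq[Int] =
      plan.collect { case exec: ShuffleExchangeLike => exec.shuffleId }

The collected IDs then feed the existing shuffleCleanupMode handling
(DoNotCleanup / SkipMigration / RemoveShuffleFiles), which is why the tests
now exercise each mode with spark.sql.adaptive.enabled both on and off.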