diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index 16f597cdb34..e8f90eceeb1 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -25,9 +25,9 @@ import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
-import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.HiveResult
+import org.apache.spark.sql.execution.{CollectLimitExec, HiveResult, LocalTableScanExec, QueryExecution, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -295,16 +295,23 @@ object SparkDatasetHelper extends Logging {
     SQLMetrics.postDriverMetricUpdates(sc, executionId, metrics.values.toSeq)
   }
 
+  private[kyuubi] def optimizedPlanLimit(queryExecution: QueryExecution): Option[Long] =
+    queryExecution.optimizedPlan match {
+      case globalLimit: GlobalLimit => globalLimit.maxRows
+      case _ => None
+    }
+
   def shouldSaveResultToFs(resultMaxRows: Int, minSize: Long, result: DataFrame): Boolean = {
     if (isCommandExec(result.queryExecution.executedPlan.nodeName)) {
       return false
     }
-    lazy val limit = result.queryExecution.executedPlan match {
-      case collectLimit: CollectLimitExec => collectLimit.limit
-      case _ => resultMaxRows
+    val finalLimit = optimizedPlanLimit(result.queryExecution) match {
+      case Some(limit) if resultMaxRows > 0 => math.min(limit, resultMaxRows)
+      case Some(limit) => limit
+      case None => resultMaxRows
     }
-    lazy val stats = if (limit > 0) {
-      limit * EstimationUtils.getSizePerRow(
+    lazy val stats = if (finalLimit > 0) {
+      finalLimit * EstimationUtils.getSizePerRow(
         result.queryExecution.executedPlan.output)
     } else {
       result.queryExecution.optimizedPlan.stats.sizeInBytes
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala
new file mode 100644
index 00000000000..8ac00e60262
--- /dev/null
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kyuubi
+
+import org.apache.spark.sql.internal.SQLConf
+
+import org.apache.kyuubi.engine.spark.WithSparkSQLEngine
+
+class SparkDatasetHelperSuite extends WithSparkSQLEngine {
+  override def withKyuubiConf: Map[String, String] = Map.empty
+
+  test("get limit from spark plan") {
+    Seq(true, false).foreach { aqe =>
+      val topKThreshold = 3
+      spark.sessionState.conf.setConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED, aqe)
+      spark.sessionState.conf.setConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD, topKThreshold)
+      spark.sql("CREATE OR REPLACE TEMPORARY VIEW tv AS" +
+        " SELECT * FROM VALUES(1),(2),(3),(4) AS t(id)")
+
+      val topKStatement = s"SELECT * FROM(SELECT * FROM tv ORDER BY id LIMIT ${topKThreshold - 1})"
+      assert(SparkDatasetHelper.optimizedPlanLimit(
+        spark.sql(topKStatement).queryExecution) === Option(topKThreshold - 1))
+
+      val collectLimitStatement =
+        s"SELECT * FROM (SELECT * FROM tv ORDER BY id LIMIT $topKThreshold)"
+      assert(SparkDatasetHelper.optimizedPlanLimit(
+        spark.sql(collectLimitStatement).queryExecution) === Option(topKThreshold))
+    }
+  }
+}
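For reference, a standalone sketch of the limit-capping rule that shouldSaveResultToFs now applies
(illustrative only, not part of the patch; FinalLimitSketch and the planLimit parameter are made-up
names standing in for optimizedPlanLimit(result.queryExecution)):

object FinalLimitSketch extends App {
  // Mirrors the finalLimit match in the patch: the plan's GlobalLimit and the
  // resultMaxRows session cap are combined by taking the smaller of the two,
  // with resultMaxRows <= 0 treated as "no cap".
  def finalLimit(planLimit: Option[Long], resultMaxRows: Int): Long = planLimit match {
    case Some(limit) if resultMaxRows > 0 => math.min(limit, resultMaxRows)
    case Some(limit) => limit
    case None => resultMaxRows
  }

  assert(finalLimit(Some(3L), 10) == 3L)   // query LIMIT 3 is tighter than the cap of 10
  assert(finalLimit(Some(50L), 10) == 10L) // the cap of 10 shrinks a larger query LIMIT
  assert(finalLimit(None, 10) == 10L)      // no GlobalLimit in the plan: fall back to the cap
  assert(finalLimit(Some(50L), 0) == 50L)  // non-positive cap means unlimited: keep the plan limit
}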