diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index e2b9cc65ebded..46a90f600b2a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -539,17 +539,22 @@ object LimitPushDown extends Rule[LogicalPlan] {
     // pushdown Limit.
     case LocalLimit(exp, u: Union) =>
       LocalLimit(exp, u.copy(children = u.children.map(maybePushLocalLimit(exp, _))))
-    // Add extra limits below OUTER JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to
-    // the left and right sides, respectively. It's not safe to push limits below FULL OUTER
-    // JOIN in the general case without a more invasive rewrite.
+    // Add extra limits below JOIN. For LEFT OUTER and RIGHT OUTER JOIN we push limits to
+    // the left and right sides, respectively. For INNER and CROSS JOIN we push limits to
+    // both the left and right sides if the join condition is empty. It's not safe to push
+    // limits below FULL OUTER JOIN in the general case without a more invasive rewrite.
     // We also need to ensure that this limit pushdown rule will not eventually introduce limits
     // on both sides if it is applied multiple times. Therefore:
     //   - If one side is already limited, stack another limit on top if the new limit is smaller.
     //     The redundant limit will be collapsed by the CombineLimits rule.
-    case LocalLimit(exp, join @ Join(left, right, joinType, _, _)) =>
+    case LocalLimit(exp, join @ Join(left, right, joinType, conditionOpt, _)) =>
       val newJoin = joinType match {
         case RightOuter => join.copy(right = maybePushLocalLimit(exp, right))
         case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left))
+        case _: InnerLike if conditionOpt.isEmpty =>
+          join.copy(
+            left = maybePushLocalLimit(exp, left),
+            right = maybePushLocalLimit(exp, right))
         case _ => join
       }
       LocalLimit(exp, newJoin)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
index bb23b63c03cea..5c760264ff219 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.Add
-import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, PlanTest, RightOuter}
+import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, PlanTest, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._

@@ -194,4 +194,22 @@ class LimitPushdownSuite extends PlanTest {
         LocalLimit(1, y.groupBy(Symbol("b"))(count(1))))).analyze
     comparePlans(expected2, optimized2)
   }
+
+  test("SPARK-26138: pushdown limit through InnerLike when condition is empty") {
+    Seq(Cross, Inner).foreach { joinType =>
+      val originalQuery = x.join(y, joinType).limit(1)
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = Limit(1, LocalLimit(1, x).join(LocalLimit(1, y), joinType)).analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
+
+  test("SPARK-26138: Should not pushdown limit through InnerLike when condition is not empty") {
+    Seq(Cross, Inner).foreach { joinType =>
+      val originalQuery = x.join(y, joinType, Some("x.a".attr === "y.b".attr)).limit(1)
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = Limit(1, x.join(y, joinType, Some("x.a".attr === "y.b".attr))).analyze
+      comparePlans(optimized, correctAnswer)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index e6a18c3894497..fe8a080ac5aeb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -29,14 +29,14 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial}
 import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite}
-import org.apache.spark.sql.catalyst.plans.logical.{Project, RepartitionByExpression}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, RepartitionByExpression}
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.UnionExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.FunctionsCommand
-import org.apache.spark.sql.execution.datasources.SchemaColumnConvertNotSupportedException
+import org.apache.spark.sql.execution.datasources.{LogicalRelation, SchemaColumnConvertNotSupportedException}
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcScan
 import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
@@ -4021,6 +4021,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  test("SPARK-26138 Pushdown limit through InnerLike when condition is empty") {
+    withTable("t1", "t2") {
+      spark.range(5).repartition(1).write.saveAsTable("t1")
+      spark.range(5).repartition(1).write.saveAsTable("t2")
+      val df = spark.sql("SELECT * FROM t1 CROSS JOIN t2 LIMIT 3")
+      val pushedLocalLimits = df.queryExecution.optimizedPlan.collect {
+        case l @ LocalLimit(_, _: LogicalRelation) => l
+      }
+      assert(pushedLocalLimits.length === 2)
+      checkAnswer(df, Row(0, 0) :: Row(0, 1) :: Row(0, 2) :: Nil)
+    }
+  }
 }

 case class Foo(bar: Option[String])
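For reference, a minimal standalone sketch (not part of the patch) of how the new rule is expected to behave, mirroring the SQLQuerySuite test above. It assumes a Spark build that includes this change; the object name and temp view names are illustrative only, and it merely prints the optimized plan rather than asserting a particular shape:

    // Sketch of the SPARK-26138 behaviour: a cross join with no join condition followed by
    // LIMIT should get a LocalLimit pushed to both join sides by the LimitPushDown rule.
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.catalyst.plans.logical.LocalLimit

    object LimitPushdownSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("SPARK-26138 sketch")
          .getOrCreate()

        spark.range(5).createOrReplaceTempView("t1")
        spark.range(5).createOrReplaceTempView("t2")

        val df = spark.sql("SELECT * FROM t1 CROSS JOIN t2 LIMIT 3")

        // With the rule applied, the optimized plan should contain a LocalLimit directly above
        // each join child (the SQLQuerySuite test asserts this against LogicalRelation children
        // of saved tables; here we only inspect and print the plan).
        val localLimits = df.queryExecution.optimizedPlan.collect { case l: LocalLimit => l }
        println(df.queryExecution.optimizedPlan.treeString)
        println(s"LocalLimit nodes in optimized plan: ${localLimits.length}")

        spark.stop()
      }
    }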