From 1c781a4354666bba4329e588a0e9a9fa8980303b Mon Sep 17 00:00:00 2001 From: Terry Kim Date: Thu, 8 Oct 2020 04:58:41 +0000 Subject: [PATCH] [SPARK-32282][SQL] Improve EnsureRquirement.reorderJoinKeys to handle more scenarios such as PartitioningCollection ### What changes were proposed in this pull request? This PR proposes to improve `EnsureRquirement.reorderJoinKeys` to handle the following scenarios: 1. If the keys cannot be reordered to match the left-side `HashPartitioning`, consider the right-side `HashPartitioning`. 2. Handle `PartitioningCollection`, which may contain `HashPartitioning` ### Why are the changes needed? 1. For the scenario 1), the current behavior matches either the left-side `HashPartitioning` or the right-side `HashPartitioning`. This means that if both sides are `HashPartitioning`, it will try to match only the left side. The following will not consider the right-side `HashPartitioning`: ``` val df1 = (0 until 10).map(i => (i % 5, i % 13)).toDF("i1", "j1") val df2 = (0 until 10).map(i => (i % 7, i % 11)).toDF("i2", "j2") df1.write.format("parquet").bucketBy(4, "i1", "j1").saveAsTable("t1")df2.write.format("parquet").bucketBy(4, "i2", "j2").saveAsTable("t2") val t1 = spark.table("t1") val t2 = spark.table("t2") val join = t1.join(t2, t1("i1") === t2("j2") && t1("i1") === t2("i2")) join.explain == Physical Plan == *(5) SortMergeJoin [i1#26, i1#26], [j2#31, i2#30], Inner :- *(2) Sort [i1#26 ASC NULLS FIRST, i1#26 ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(i1#26, i1#26, 4), true, [id=#69] : +- *(1) Project [i1#26, j1#27] : +- *(1) Filter isnotnull(i1#26) : +- *(1) ColumnarToRow : +- FileScan parquet default.t1[i1#26,j1#27] Batched: true, DataFilters: [isnotnull(i1#26)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(i1)], ReadSchema: struct, SelectedBucketsCount: 4 out of 4 +- *(4) Sort [j2#31 ASC NULLS FIRST, i2#30 ASC NULLS FIRST], false, 0. +- Exchange hashpartitioning(j2#31, i2#30, 4), true, [id=#79]. <===== This can be removed +- *(3) Project [i2#30, j2#31] +- *(3) Filter (((j2#31 = i2#30) AND isnotnull(j2#31)) AND isnotnull(i2#30)) +- *(3) ColumnarToRow +- FileScan parquet default.t2[i2#30,j2#31] Batched: true, DataFilters: [(j2#31 = i2#30), isnotnull(j2#31), isnotnull(i2#30)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(j2), IsNotNull(i2)], ReadSchema: struct, SelectedBucketsCount: 4 out of 4 ``` 2. For the scenario 2), the current behavior does not handle `PartitioningCollection`: ``` val df1 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i1", "j1") val df2 = (0 until 100).map(i => (i % 7, i % 11)).toDF("i2", "j2") val df3 = (0 until 100).map(i => (i % 5, i % 13)).toDF("i3", "j3") val join = df1.join(df2, df1("i1") === df2("i2") && df1("j1") === df2("j2")) // PartitioningCollection val join2 = join.join(df3, join("j1") === df3("j3") && join("i1") === df3("i3")) join2.explain == Physical Plan == *(9) SortMergeJoin [j1#8, i1#7], [j3#30, i3#29], Inner :- *(6) Sort [j1#8 ASC NULLS FIRST, i1#7 ASC NULLS FIRST], false, 0. <===== This can be removed : +- Exchange hashpartitioning(j1#8, i1#7, 5), true, [id=#58] <===== This can be removed : +- *(5) SortMergeJoin [i1#7, j1#8], [i2#18, j2#19], Inner : :- *(2) Sort [i1#7 ASC NULLS FIRST, j1#8 ASC NULLS FIRST], false, 0 : : +- Exchange hashpartitioning(i1#7, j1#8, 5), true, [id=#45] : : +- *(1) Project [_1#2 AS i1#7, _2#3 AS j1#8] : : +- *(1) LocalTableScan [_1#2, _2#3] : +- *(4) Sort [i2#18 ASC NULLS FIRST, j2#19 ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(i2#18, j2#19, 5), true, [id=#51] : +- *(3) Project [_1#13 AS i2#18, _2#14 AS j2#19] : +- *(3) LocalTableScan [_1#13, _2#14] +- *(8) Sort [j3#30 ASC NULLS FIRST, i3#29 ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(j3#30, i3#29, 5), true, [id=#64] +- *(7) Project [_1#24 AS i3#29, _2#25 AS j3#30] +- *(7) LocalTableScan [_1#24, _2#25] ``` ### Does this PR introduce _any_ user-facing change? Yes, now from the above examples, the shuffle/sort nodes pointed by `This can be removed` are now removed: 1. Senario 1): ``` == Physical Plan == *(4) SortMergeJoin [i1#26, i1#26], [i2#30, j2#31], Inner :- *(2) Sort [i1#26 ASC NULLS FIRST, i1#26 ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(i1#26, i1#26, 4), true, [id=#67] : +- *(1) Project [i1#26, j1#27] : +- *(1) Filter isnotnull(i1#26) : +- *(1) ColumnarToRow : +- FileScan parquet default.t1[i1#26,j1#27] Batched: true, DataFilters: [isnotnull(i1#26)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(i1)], ReadSchema: struct, SelectedBucketsCount: 4 out of 4 +- *(3) Sort [i2#30 ASC NULLS FIRST, j2#31 ASC NULLS FIRST], false, 0 +- *(3) Project [i2#30, j2#31] +- *(3) Filter (((j2#31 = i2#30) AND isnotnull(j2#31)) AND isnotnull(i2#30)) +- *(3) ColumnarToRow +- FileScan parquet default.t2[i2#30,j2#31] Batched: true, DataFilters: [(j2#31 = i2#30), isnotnull(j2#31), isnotnull(i2#30)], Format: Parquet, Location: InMemoryFileIndex[..., PartitionFilters: [], PushedFilters: [IsNotNull(j2), IsNotNull(i2)], ReadSchema: struct, SelectedBucketsCount: 4 out of 4 ``` 2. Scenario 2): ``` == Physical Plan == *(8) SortMergeJoin [i1#7, j1#8], [i3#29, j3#30], Inner :- *(5) SortMergeJoin [i1#7, j1#8], [i2#18, j2#19], Inner : :- *(2) Sort [i1#7 ASC NULLS FIRST, j1#8 ASC NULLS FIRST], false, 0 : : +- Exchange hashpartitioning(i1#7, j1#8, 5), true, [id=#43] : : +- *(1) Project [_1#2 AS i1#7, _2#3 AS j1#8] : : +- *(1) LocalTableScan [_1#2, _2#3] : +- *(4) Sort [i2#18 ASC NULLS FIRST, j2#19 ASC NULLS FIRST], false, 0 : +- Exchange hashpartitioning(i2#18, j2#19, 5), true, [id=#49] : +- *(3) Project [_1#13 AS i2#18, _2#14 AS j2#19] : +- *(3) LocalTableScan [_1#13, _2#14] +- *(7) Sort [i3#29 ASC NULLS FIRST, j3#30 ASC NULLS FIRST], false, 0 +- Exchange hashpartitioning(i3#29, j3#30, 5), true, [id=#58] +- *(6) Project [_1#24 AS i3#29, _2#25 AS j3#30] +- *(6) LocalTableScan [_1#24, _2#25] ``` ### How was this patch tested? Added tests. Closes #29074 from imback82/reorder_keys. Authored-by: Terry Kim Signed-off-by: Wenchen Fan --- .../exchange/EnsureRequirements.scala | 58 +++++++-- .../exchange/EnsureRequirementsSuite.scala | 122 ++++++++++++++++++ 2 files changed, 168 insertions(+), 12 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b176598ed8c2c..3641654b89b76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -135,9 +135,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { leftKeys: IndexedSeq[Expression], rightKeys: IndexedSeq[Expression], expectedOrderOfKeys: Seq[Expression], - currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + currentOrderOfKeys: Seq[Expression]): Option[(Seq[Expression], Seq[Expression])] = { if (expectedOrderOfKeys.size != currentOrderOfKeys.size) { - return (leftKeys, rightKeys) + return None + } + + // Check if the current order already satisfies the expected order. + if (expectedOrderOfKeys.zip(currentOrderOfKeys).forall(p => p._1.semanticEquals(p._2))) { + return Some(leftKeys, rightKeys) } // Build a lookup between an expression and the positions its holds in the current key seq. @@ -164,10 +169,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { rightKeysBuffer += rightKeys(index) case _ => // The expression cannot be found, or we have exhausted all indices for that expression. - return (leftKeys, rightKeys) + return None } } - (leftKeysBuffer.toSeq, rightKeysBuffer.toSeq) + Some(leftKeysBuffer.toSeq, rightKeysBuffer.toSeq) } private def reorderJoinKeys( @@ -176,19 +181,48 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { leftPartitioning: Partitioning, rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - (leftPartitioning, rightPartitioning) match { - case (HashPartitioning(leftExpressions, _), _) => - reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys) - case (_, HashPartitioning(rightExpressions, _)) => - reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) - case _ => - (leftKeys, rightKeys) - } + reorderJoinKeysRecursively( + leftKeys, + rightKeys, + Some(leftPartitioning), + Some(rightPartitioning)) + .getOrElse((leftKeys, rightKeys)) } else { (leftKeys, rightKeys) } } + /** + * Recursively reorders the join keys based on partitioning. It starts reordering the + * join keys to match HashPartitioning on either side, followed by PartitioningCollection. + */ + private def reorderJoinKeysRecursively( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + leftPartitioning: Option[Partitioning], + rightPartitioning: Option[Partitioning]): Option[(Seq[Expression], Seq[Expression])] = { + (leftPartitioning, rightPartitioning) match { + case (Some(HashPartitioning(leftExpressions, _)), _) => + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leftExpressions, leftKeys) + .orElse(reorderJoinKeysRecursively( + leftKeys, rightKeys, None, rightPartitioning)) + case (_, Some(HashPartitioning(rightExpressions, _))) => + reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys) + .orElse(reorderJoinKeysRecursively( + leftKeys, rightKeys, leftPartitioning, None)) + case (Some(PartitioningCollection(partitionings)), _) => + partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) => + res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, Some(p), rightPartitioning)) + }.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, None, rightPartitioning)) + case (_, Some(PartitioningCollection(partitionings))) => + partitionings.foldLeft(Option.empty[(Seq[Expression], Seq[Expression])]) { (res, p) => + res.orElse(reorderJoinKeysRecursively(leftKeys, rightKeys, leftPartitioning, Some(p))) + }.orElse(None) + case _ => + None + } + } + /** * When the physical operators are created for JOIN, the ordering of join keys is based on order * in which the join keys appear in the user query. That might not match with the output diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala new file mode 100644 index 0000000000000..38e68cd2512e7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.exchange + +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection} +import org.apache.spark.sql.execution.{DummySparkPlan, SortExec} +import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.test.SharedSparkSession + +class EnsureRequirementsSuite extends SharedSparkSession { + private val exprA = Literal(1) + private val exprB = Literal(2) + private val exprC = Literal(3) + + test("reorder should handle PartitioningCollection") { + val plan1 = DummySparkPlan( + outputPartitioning = PartitioningCollection(Seq( + HashPartitioning(exprA :: exprB :: Nil, 5), + HashPartitioning(exprA :: Nil, 5)))) + val plan2 = DummySparkPlan() + + // Test PartitioningCollection on the left side of join. + val smjExec1 = SortMergeJoinExec( + exprB :: exprA :: Nil, exprA :: exprB :: Nil, Inner, None, plan1, plan2) + EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprA, exprB)) + assert(rightKeys === Seq(exprB, exprA)) + case other => fail(other.toString) + } + + // Test PartitioningCollection on the right side of join. + val smjExec2 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprB :: exprA :: Nil, Inner, None, plan2, plan1) + EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => + assert(leftKeys === Seq(exprB, exprA)) + assert(rightKeys === Seq(exprA, exprB)) + case other => fail(other.toString) + } + + // Both sides are PartitioningCollection, but left side cannot be reorderd to match + // and it should fall back to the right side. + val smjExec3 = SortMergeJoinExec( + exprA :: exprC :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan1) + EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => + assert(leftKeys === Seq(exprC, exprA)) + assert(rightKeys === Seq(exprA, exprB)) + case other => fail(other.toString) + } + } + + test("reorder should fallback to the other side partitioning") { + val plan1 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprA :: exprB :: exprC :: Nil, 5)) + val plan2 = DummySparkPlan( + outputPartitioning = HashPartitioning(exprB :: exprC :: Nil, 5)) + + // Test fallback to the right side, which has HashPartitioning. + val smjExec1 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan2) + EnsureRequirements(spark.sessionState.conf).apply(smjExec1) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprB, exprA)) + assert(rightKeys === Seq(exprB, exprC)) + case other => fail(other.toString) + } + + // Test fallback to the right side, which has PartitioningCollection. + val plan3 = DummySparkPlan( + outputPartitioning = PartitioningCollection(Seq(HashPartitioning(exprB :: exprC :: Nil, 5)))) + val smjExec2 = SortMergeJoinExec( + exprA :: exprB :: Nil, exprC :: exprB :: Nil, Inner, None, plan1, plan3) + EnsureRequirements(spark.sessionState.conf).apply(smjExec2) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), _) => + assert(leftKeys === Seq(exprB, exprA)) + assert(rightKeys === Seq(exprB, exprC)) + case other => fail(other.toString) + } + + // The right side has HashPartitioning, so it is matched first, but no reordering match is + // found, and it should fall back to the left side, which has a PartitioningCollection. + val smjExec3 = SortMergeJoinExec( + exprC :: exprB :: Nil, exprA :: exprB :: Nil, Inner, None, plan3, plan1) + EnsureRequirements(spark.sessionState.conf).apply(smjExec3) match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, + SortExec(_, _, DummySparkPlan(_, _, _: PartitioningCollection, _, _), _), + SortExec(_, _, ShuffleExchangeExec(_: HashPartitioning, _, _), _), _) => + assert(leftKeys === Seq(exprB, exprC)) + assert(rightKeys === Seq(exprB, exprA)) + case other => fail(other.toString) + } + } +}