Commit eae5ca7

summaryzb authored and LuciferYang committed
[SPARK-51704][SQL] Eliminate unnecessary collect operation
### What changes were proposed in this pull request?

Change the operation on `TreeNode` from `collect` to `collectFirst` in the following scenarios:

- The final purpose is to find the first qualifying node in a pre-order traversal.
- The final purpose is to verify that no node, or at least one node, satisfies the requirement.

Two conditions must hold:

- The operation is applied recursively to the `TreeNode` and all of its descendants, including the node itself.
- The partial function being applied does not affect the node or any other related object.

### Why are the changes needed?

To avoid unnecessary traversal of the `TreeNode`.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

- Pass GitHub Actions.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50494 from summaryzb/sql_collect.

Authored-by: summaryzb <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
1 parent 9d3f937 commit eae5ca7
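
To make the intent of the change concrete, here is a minimal, self-contained sketch of the idea (plain Scala, not Spark's actual `TreeNode` code; the `Node`, `Branch`, and `Leaf` types and the two helper methods are hypothetical stand-ins). It shows why an existence check written with `collectFirst` can stop at the first match, while the same check written with `collect` must visit every node in the tree:

```scala
// Minimal stand-in for a tree of plan nodes; not Spark code.
sealed trait Node { def children: Seq[Node] }
case class Branch(children: Node*) extends Node
case class Leaf(value: Int) extends Node { def children: Seq[Node] = Nil }

object CollectVsCollectFirst {
  // Pre-order collect: applies the partial function to every node and
  // gathers all matches, so the whole tree is always traversed.
  def collect[B](node: Node)(pf: PartialFunction[Node, B]): Seq[B] =
    pf.lift(node).toSeq ++ node.children.flatMap(c => collect(c)(pf))

  // Pre-order collectFirst: orElse is evaluated lazily, so the traversal
  // stops descending as soon as the first match is found.
  def collectFirst[B](node: Node)(pf: PartialFunction[Node, B]): Option[B] =
    pf.lift(node).orElse(
      node.children.foldLeft(Option.empty[B]) { (acc, c) =>
        acc.orElse(collectFirst(c)(pf))
      })

  def main(args: Array[String]): Unit = {
    val tree = Branch(Leaf(1), Branch(Leaf(2), Leaf(3)), Leaf(4))
    // Both express "is there an even leaf?", but the first builds the full
    // sequence of matches while the second stops at Leaf(2).
    val viaCollect = collect(tree) { case Leaf(v) if v % 2 == 0 => v }.nonEmpty
    val viaCollectFirst = collectFirst(tree) { case Leaf(v) if v % 2 == 0 => v }.nonEmpty
    assert(viaCollect == viaCollectFirst) // same answer, less work
  }
}
```

In the patch below, the same idea shows up as replacing `plan.collect { ... }.nonEmpty` (or `.isEmpty` / `.size === 0`) with `plan.collectFirst { ... }.nonEmpty` (or `.isEmpty` / `.isDefined`): the boolean result is unchanged, but Catalyst's `TreeNode.collectFirst` can return early instead of materializing every match.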

File tree: 12 files changed (+33 −31 lines)


connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -1614,7 +1614,7 @@ abstract class KafkaMicroBatchV1SourceSuite extends KafkaMicroBatchSourceSuiteBa
     testStream(kafka)(
       makeSureGetOffsetCalled,
       AssertOnQuery { query =>
-        query.logicalPlan.collect {
+        query.logicalPlan.collectFirst {
           case StreamingExecutionRelation(_: KafkaSource, _, _) => true
         }.nonEmpty
       }

connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -637,7 +637,7 @@ class KafkaRelationSuiteV1 extends KafkaRelationSuiteBase {
   test("V1 Source is used when set through SQLConf") {
     val topic = newTopic()
     val df = createDF(topic)
-    assert(df.logicalPlan.collect {
+    assert(df.logicalPlan.collectFirst {
       case _: LogicalRelation => true
     }.nonEmpty)
   }
@@ -652,7 +652,7 @@ class KafkaRelationSuiteV2 extends KafkaRelationSuiteBase {
   test("V2 Source is used when set through SQLConf") {
     val topic = newTopic()
     val df = createDF(topic)
-    assert(df.logicalPlan.collect {
+    assert(df.logicalPlan.collectFirst {
       case _: DataSourceV2Relation => true
     }.nonEmpty)
   }

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 1 deletion
@@ -1833,7 +1833,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    * Returns true if `exprs` contains a [[Star]].
    */
   def containsStar(exprs: Seq[Expression]): Boolean =
-    exprs.exists(_.collect { case _: Star => true }.nonEmpty)
+    exprs.exists(_.collectFirst { case _: Star => true }.nonEmpty)
 
   private def extractStar(exprs: Seq[Expression]): Seq[Star] =
     exprs.flatMap(_.collect { case s: Star => s })

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala

Lines changed: 1 addition & 1 deletion
@@ -559,7 +559,7 @@ abstract class ProgressContext(
       hasNewData: Boolean,
       sourceToNumInputRows: Map[SparkDataStream, Long],
       lastExecution: IncrementalExecution): ExecutionStats = {
-    val hasEventTime = progressReporter.logicalPlan().collect {
+    val hasEventTime = progressReporter.logicalPlan().collectFirst {
       case e: EventTimeWatermark => e
     }.nonEmpty

sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -170,7 +170,7 @@ abstract class CTEInlineSuiteBase
       }.head.length == 2,
       "With-CTE should contain 2 CTE defs after analysis.")
     assert(
-      df.queryExecution.optimizedPlan.collect {
+      df.queryExecution.optimizedPlan.collectFirst {
         case r: RepartitionOperation => r
       }.isEmpty,
       "CTEs with one reference should all be inlined after optimization.")
@@ -255,7 +255,7 @@ abstract class CTEInlineSuiteBase
       }.head.length == 2,
       "With-CTE should contain 2 CTE defs after analysis.")
     assert(
-      df.queryExecution.optimizedPlan.collect {
+      df.queryExecution.optimizedPlan.collectFirst {
         case r: RepartitionOperation => r
       }.isEmpty,
       "Deterministic CTEs should all be inlined after optimization.")

sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -508,7 +508,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
           |join abc c on a.key=c.key""".stripMargin).queryExecution.sparkPlan
 
       assert(sparkPlan.collect { case e: InMemoryTableScanExec => e }.size === 3)
-      assert(sparkPlan.collect { case e: RDDScanExec => e }.size === 0)
+      assert(sparkPlan.collectFirst { case e: RDDScanExec => e }.isEmpty)
     }
   }
 
@@ -923,7 +923,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils
     withSQLConf(SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> vectorized.toString) {
       val cache = spark.range(10).cache()
       val df = cache.filter($"id" > 0)
-      val columnarToRow = df.queryExecution.executedPlan.collect {
+      val columnarToRow = df.queryExecution.executedPlan.collectFirst {
         case c: ColumnarToRowExec => c
       }
       assert(columnarToRow.isEmpty)

sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala

Lines changed: 8 additions & 6 deletions
@@ -32,10 +32,11 @@ class DatasetOptimizationSuite extends QueryTest with SharedSparkSession {
   test("SPARK-26619: Prune the unused serializers from SerializeFromObject") {
     val data = Seq(("a", 1), ("b", 2), ("c", 3))
     val ds = data.toDS().map(t => (t._1, t._2 + 1)).select("_1")
-    val serializer = ds.queryExecution.optimizedPlan.collect {
+    val serializerOpt = ds.queryExecution.optimizedPlan.collectFirst {
       case s: SerializeFromObject => s
-    }.head
-    assert(serializer.serializer.size == 1)
+    }
+    assert(serializerOpt.isDefined)
+    assert(serializerOpt.get.serializer.size == 1)
     checkAnswer(ds, Seq(Row("a"), Row("b"), Row("c")))
   }
 
@@ -45,15 +46,16 @@ class DatasetOptimizationSuite extends QueryTest with SharedSparkSession {
   // serializers. The first `structFields` is aligned with first serializer and ditto
   // for other `structFields`.
   private def testSerializer(df: DataFrame, structFields: Seq[Seq[String]]*): Unit = {
-    val serializer = df.queryExecution.optimizedPlan.collect {
+    val serializerOpt = df.queryExecution.optimizedPlan.collectFirst {
       case s: SerializeFromObject => s
-    }.head
+    }
 
     def collectNamedStruct: PartialFunction[Expression, Seq[CreateNamedStruct]] = {
       case c: CreateNamedStruct => Seq(c)
     }
 
-    serializer.serializer.zip(structFields).foreach { case (ser, fields) =>
+    assert(serializerOpt.isDefined)
+    serializerOpt.get.serializer.zip(structFields).foreach { case (ser, fields) =>
       val structs: Seq[CreateNamedStruct] = ser.collect(collectNamedStruct).flatten
       assert(structs.size == fields.size)
       structs.zip(fields).foreach { case (struct, fieldNames) =>

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ object QueryTest extends Assertions {
       df: DataFrame,
       expectedAnswer: Seq[Row],
       checkToRDD: Boolean = true): Option[String] = {
-    val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+    val isSorted = df.logicalPlan.collectFirst { case s: logical.Sort => s }.nonEmpty
     if (checkToRDD) {
       SQLExecution.withSQLConfPropagated(df.sparkSession) {
         df.materializedRdd.count() // Also attempt to deserialize as an RDD [SPARK-15791]

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 5 additions & 5 deletions
@@ -821,11 +821,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       case cp: CartesianProductExec => cp
     }
     assert(cp.isEmpty, "should not use CartesianProduct for null-safe join")
-    val smj = df.queryExecution.sparkPlan.collect {
+    val smj = df.queryExecution.sparkPlan.collectFirst {
       case smj: SortMergeJoinExec => smj
       case j: BroadcastHashJoinExec => j
     }
-    assert(smj.size > 0, "should use SortMergeJoin or BroadcastHashJoin")
+    assert(smj.nonEmpty, "should use SortMergeJoin or BroadcastHashJoin")
     checkAnswer(df, Row(100) :: Nil)
   }
 
@@ -3815,7 +3815,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     Seq(1, "1, 2", null, "version()").foreach { expr =>
       val plan = sql(s"select * from values (1), (2), (3) t(a) distribute by $expr")
         .queryExecution.optimizedPlan
-      val res = plan.collect {
+      val res = plan.collectFirst {
         case r: RepartitionByExpression if r.numPartitions == 1 => true
       }
       assert(res.nonEmpty)
@@ -3827,7 +3827,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     withSQLConf((SQLConf.SHUFFLE_PARTITIONS.key, "5")) {
       val df = spark.range(1).hint("REPARTITION_BY_RANGE")
       val plan = df.queryExecution.optimizedPlan
-      val res = plan.collect {
+      val res = plan.collectFirst {
         case r: RepartitionByExpression if r.numPartitions == 5 => true
       }
       assert(res.nonEmpty)
@@ -3839,7 +3839,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
     Seq(1, "1, 2", null, "version()").foreach { expr =>
       val plan = sql(s"select * from values (1), (2), (3) t(a) distribute by $expr")
         .queryExecution.analyzed
-      val res = plan.collect {
+      val res = plan.collectFirst {
         case r: RepartitionByExpression if r.numPartitions == 2 => true
       }
       assert(res.nonEmpty)

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 1 addition & 1 deletion
@@ -2722,7 +2722,7 @@ class SubquerySuite extends QueryTest
         |SELECT * FROM v1 WHERE kind = (SELECT kind FROM v1 WHERE kind = 'foo')
         |""".stripMargin)
     val df = sql("SELECT * FROM v1 JOIN v2 ON v1.id = v2.id")
-    val filter = df.queryExecution.optimizedPlan.collect {
+    val filter = df.queryExecution.optimizedPlan.collectFirst {
       case f: Filter => f
     }
     assert(filter.isEmpty,
