From 0a5498052a1a992ec292293cf391309691ab7b8c Mon Sep 17 00:00:00 2001
From: Tom van Bussel
Date: Fri, 29 Sep 2023 16:28:56 +0200
Subject: [PATCH] Inline non-correlated scalar subqueries in Merge

---
 .../sql/delta/PreprocessTableMerge.scala      |  47 +++++++-
 .../MergeIntoMaterializeSourceSuite.scala     |  43 ++++++-
 .../spark/sql/delta/MergeIntoSuiteBase.scala  | 107 ++++++++++++++++++
 3 files changed, 193 insertions(+), 4 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala b/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala
index 9e17bbe5819..cdaf41b32fe 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/PreprocessTableMerge.scala
@@ -20,12 +20,11 @@ import java.time.{Instant, LocalDateTime}
 import java.util.Locale
 
 import scala.collection.mutable
-import scala.reflect.ClassTag
 
 import org.apache.spark.sql.delta.commands.MergeIntoCommand
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
-import org.apache.spark.sql.{AnalysisException, SparkSession}
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -35,6 +34,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros}
+import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, DateType, StringType, StructField, StructType, TimestampNTZType, TimestampType}
@@ -200,11 +200,15 @@ case class PreprocessTableMerge(override val conf: SQLConf)
      * invoking the [[ComputeCurrentTime]] rule. This is why they need special handling.
      */
     val now = Instant.now()
+
+    val sourceWithInlinedSubqueries =
+      inlineSubqueryResults(SparkSession.active, transformTimestamps(source, now))
+
     // Transform timestamps for the MergeIntoCommand, source, and target using the same instant.
     // Called explicitly because source and target are not children of MergeIntoCommand.
     transformTimestamps(
       MergeIntoCommand(
-        transformTimestamps(source, now),
+        sourceWithInlinedSubqueries,
         transformTimestamps(target, now),
         relation.catalogTable,
         tahoeFileIndex,
@@ -227,6 +231,43 @@ case class PreprocessTableMerge(override val conf: SQLConf)
     }
   }
 
+  /**
+   * Inlines the results of non-correlated scalar subqueries in the `source` of the
+   * `MergeIntoCommand`. This is necessary to deal with "deterministic" scalar subqueries that
+   * can return non-deterministic results, e.g. a query with a LIMIT 1 but no ORDER BY.
+   * In most cases these subqueries are evaluated only once as part of the source
+   * materialization, but filters containing scalar subqueries can be inferred from the
+   * materialized source and propagated to the target side of the join, where the subqueries
+   * would be evaluated again.
+   */
+  private def inlineSubqueryResults(spark: SparkSession, source: LogicalPlan): LogicalPlan = {
+    // Gather all non-correlated scalar subqueries in the source.
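+    // Correlated subqueries are deliberately skipped: they reference attributes of the
+    // enclosing plan and therefore cannot be evaluated up front as a standalone one-row query.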
+    val subqueries = source.flatMap {
+      _.expressions.flatMap(_.collect { case s: ScalarSubquery if !s.isCorrelated => s })
+    }
+    if (subqueries.isEmpty) {
+      return source
+    }
+
+    // Evaluate all non-correlated scalar subqueries in a single query to enable subquery reuse.
+    val namedSubqueries = subqueries.map { s =>
+      Alias(s, s"subquery-${s.exprId.id}")()
+    }
+    val qe = new QueryExecution(spark, Project(namedSubqueries, OneRowRelation()))
+    val result = SQLExecution.withNewExecutionId(qe) {
+      qe.executedPlan.executeCollect().head
+    }
+
+    // Replace the subqueries in the source with their results.
+    val subqueryResults = subqueries.zipWithIndex.map { case (s, i) =>
+      s.exprId.id -> Literal.create(result.get(i, s.dataType), s.dataType)
+    }.toMap
+    val newSource = source.transformAllExpressions {
+      case s: ScalarSubquery if !s.isCorrelated => subqueryResults(s.exprId.id)
+    }
+
+    newSource
+  }
+
   private def transformTimestamps(plan: LogicalPlan, instant: Instant): LogicalPlan = {
     import org.apache.spark.sql.delta.implicits._
 
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoMaterializeSourceSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoMaterializeSourceSuite.scala
index 6fb92c10e3a..9b6e615c8ac 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoMaterializeSourceSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoMaterializeSourceSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.delta.util.JsonUtils
 import org.scalactic.source.Position
 import org.scalatest.Tag
 
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkException, TaskContext}
 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -770,6 +770,47 @@ trait MergeIntoMaterializeSourceTests
     assert(stats.materializeSourceReason.isDefined)
     stats.materializeSourceReason.get
   }
+
+  test("merge with non-deterministic scalar subquery in source") {
+    {
+      val numRows = 1000
+
+      val chooseOneRow = udf { i: Long =>
+        val stageId = TaskContext.get().stageId()
+        stageId % numRows == i
+      }
+      assert(chooseOneRow.deterministic)
+
+      // Create a subquery that returns a different row every time it is executed.
+      val subqueryViewName = "subquery_view"
+      spark.range(start = 0, end = numRows, step = 1, numPartitions = 1)
+        .filter(chooseOneRow(col("id")))
+        .createTempView(subqueryViewName)
+
+      val targetTableName = "target_table"
+      val sourceTableName = "source_table"
+      withTable(targetTableName, sourceTableName) {
+        spark.range(numRows).select(col("id").as("key"), col("id").as("value"))
+          .write.mode("overwrite").format("delta").saveAsTable(targetTableName)
+        spark.sql(s"SELECT key, value + $numRows AS value FROM $targetTableName")
+          .write.mode("overwrite").format("delta").saveAsTable(sourceTableName)
+
+        spark.sql(
+          s"""MERGE INTO $targetTableName t
+             |USING (SELECT * FROM $sourceTableName WHERE key = (SELECT * FROM $subqueryViewName)) s
+             |ON t.key = s.key
+             |WHEN MATCHED THEN UPDATE SET *
+             |WHEN NOT MATCHED THEN INSERT *""".stripMargin
+        )
+
+        // No new rows should have been inserted, as all keys in the source are already
+        // present in the target.
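+        // The source is materialized, but a filter containing the scalar subquery can be
+        // inferred from the materialized source and propagated to the target side of the
+        // join, where the subquery is evaluated independently.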
+        // If the subquery is evaluated multiple times like this, the source may return
+        // different rows when finding the touched files and when writing the modified rows,
+        // in which case an update may be incorrectly treated as an insert.
+        assert(spark.table(targetTableName).count() === numRows)
+      }
+    }
+  }
 }
 
 // MERGE + materialize
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala
index 2f14dde13ba..2bb7abec1c6 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/MergeIntoSuiteBase.scala
@@ -520,6 +520,113 @@ abstract class MergeIntoSuiteBase
     }
   }
 
+  test("non-correlated scalar subquery in source query") {
+    withTable("source") {
+      Seq((1, 6, "a"), (0, 3, "b")).toDF("key1", "value", "others")
+        .createOrReplaceTempView("source")
+      append(Seq((2, 2), (1, 4)).toDF("key2", "value"))
+
+      executeMerge(
+        target = s"delta.`$tempPath` as trg",
+        source = "(SELECT * FROM source WHERE value = (SELECT MAX(value) FROM source)) src",
+        condition = "src.key1 = trg.key2",
+        update = "trg.key2 = 20 + key1, value = 20 + src.value",
+        insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)")
+
+      checkAnswer(readDeltaTable(tempPath),
+        Row(2, 2) :: // No change
+        Row(21, 26) :: // Update
+        Nil)
+    }
+  }
+
+  test("correlated scalar subquery in source query") {
+    withTable("source") {
+      Seq((1, 6, "a"), (0, 3, "b")).toDF("key1", "value", "others")
+        .createOrReplaceTempView("source")
+      append(Seq((2, 2), (1, 4)).toDF("key2", "value"))
+
+      executeMerge(
+        target = s"delta.`$tempPath` as trg",
+        source = "(SELECT * FROM source WHERE " +
+          s"value = (SELECT MAX(value) FROM delta.`$tempPath` WHERE key1 = key2)) src",
+        condition = "src.key1 = trg.key2",
+        update = "trg.key2 = 20 + key1, value = 20 + src.value",
+        insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)")
+
+      checkAnswer(readDeltaTable(tempPath),
+        Row(2, 2) :: // No change
+        Row(1, 4) :: // No change
+        Nil)
+    }
+  }
+
+  test("non-correlated exists subquery in source query") {
+    withTable("source") {
+      Seq((1, 6, "a"), (0, 3, "b")).toDF("key1", "value", "others")
+        .createOrReplaceTempView("source")
+      append(Seq((2, 2), (1, 4)).toDF("key2", "value"))
+
+      executeMerge(
+        target = s"delta.`$tempPath` as trg",
+        source = s"(SELECT * FROM source WHERE EXISTS (SELECT * FROM delta.`$tempPath`)) src",
+        condition = "src.key1 = trg.key2",
+        update = "trg.key2 = 20 + key1, value = 20 + src.value",
+        insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)")
+
+      checkAnswer(
+        readDeltaTable(tempPath),
+        Row(2, 2) :: // No change
+        Row(21, 26) :: // Update
+        Row(-10, 13) :: // Insert
+        Nil)
+    }
+  }
+
+  test("correlated exists subquery in source query") {
+    withTable("source") {
+      Seq((1, 6, "a"), (0, 3, "b")).toDF("key1", "value", "others")
+        .createOrReplaceTempView("source")
+      append(Seq((2, 2), (1, 4)).toDF("key2", "value"))
+
+      executeMerge(
+        target = s"delta.`$tempPath` as trg",
+        source = s"(SELECT * FROM source WHERE " +
+          s"EXISTS (SELECT * FROM delta.`$tempPath` WHERE key1 = key2)) src",
+        condition = "src.key1 = trg.key2",
+        update = "trg.key2 = 20 + key1, value = 20 + src.value",
+        insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)")
+
+      checkAnswer(
+        readDeltaTable(tempPath),
+        Row(2, 2) :: // No change
+        Row(21, 26) :: // Update
+        Nil)
+    }
+  }
+
+  test("in subquery in source query") {
+    withTable("source") {
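+      // Only key1 = 1 appears among the target keys (2 and 1), so the IN filter keeps just
+      // the first source row.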
"a"), (0, 3, "b")).toDF("key1", "value", "others") + .createOrReplaceTempView("source") + append(Seq((2, 2), (1, 4)).toDF("key2", "value")) + + executeMerge( + target = s"delta.`$tempPath` as trg", + source = s"(SELECT * FROM source WHERE " + + s"key1 IN (SELECT key2 FROM delta.`$tempPath`)) src", + condition = "src.key1 = trg.key2", + update = "trg.key2 = 20 + key1, value = 20 + src.value", + insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)") + + checkAnswer( + readDeltaTable(tempPath), + Row(2, 2) :: // No change + Row(21, 26) :: // Update + Nil) + } + } + testQuietly("Negative case - more than one source rows match the same target row") { withTable("source") { Seq((1, 1), (0, 3), (1, 5)).toDF("key1", "value").createOrReplaceTempView("source")