Skip to content

Commit

Permalink
Inline non-correlated scalar subqueries in Merge
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvanbussel committed Oct 4, 2023
1 parent 5f9b98e commit 0a54980
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ import java.time.{Instant, LocalDateTime}
import java.util.Locale

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.sql.delta.commands.MergeIntoCommand
import org.apache.spark.sql.delta.sources.DeltaSQLConf

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
Expand All @@ -35,6 +34,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.CURRENT_LIKE
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros}
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DateType, StringType, StructField, StructType, TimestampNTZType, TimestampType}
Expand Down Expand Up @@ -200,11 +200,15 @@ case class PreprocessTableMerge(override val conf: SQLConf)
* invoking the [[ComputeCurrentTime]] rule. This is why they need special handling.
*/
val now = Instant.now()

val sourceWithInlinedSubqueries =
inlineSubqueryResults(SparkSession.active, transformTimestamps(source, now))

// Transform timestamps for the MergeIntoCommand, source, and target using the same instant.
// Called explicitly because source and target are not children of MergeIntoCommand.
transformTimestamps(
MergeIntoCommand(
transformTimestamps(source, now),
sourceWithInlinedSubqueries,
transformTimestamps(target, now),
relation.catalogTable,
tahoeFileIndex,
Expand All @@ -227,6 +231,43 @@ case class PreprocessTableMerge(override val conf: SQLConf)
}
}

/**
* Inlines the results of the subqueries in the `source` of the `MergeIntoCommand`.
* This is necessary to deal with "deterministic" scalar subqueries that can return
* non-deterministic results. E.g. a query with a LIMIT 1 without an ORDER BY.
* In most cases these subqueries are evaluated only once as part of the source materialization,
* but scalar subqueries can be inferred from the materialized source and propagated to the target
* side of the join.
*/
private def inlineSubqueryResults(spark: SparkSession, source: LogicalPlan): LogicalPlan = {
// Gather all non-correlated scalar subqueries in the source.
val subqueries = source.flatMap {
_.expressions.flatMap(_.collect { case s: ScalarSubquery if !s.isCorrelated => s })
}
if (subqueries.isEmpty) {
return source
}

// Evaluate all non-correlated scalar subqueries in a single query to enable subquery reuse.
val namedSubqueries = subqueries.map { s =>
Alias(s, s"subquery-${s.exprId.id}")()
}
val qe = new QueryExecution(spark, Project(namedSubqueries, OneRowRelation()))
val result = SQLExecution.withNewExecutionId(qe) {
qe.executedPlan.executeCollect().head
}

// Replace the subqueries in the source and target with their results.
val subqueryResults = subqueries.zipWithIndex.map { case (s, i) =>
s.exprId.id -> Literal.create(result.get(i, s.dataType), s.dataType)
}.toMap
val newSource = source.transformAllExpressions {
case s: ScalarSubquery if !s.isCorrelated => subqueryResults(s.exprId.id)
}

newSource
}

private def transformTimestamps(plan: LogicalPlan, instant: Instant): LogicalPlan = {
import org.apache.spark.sql.delta.implicits._

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.delta.util.JsonUtils
import org.scalactic.source.Position
import org.scalatest.Tag

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.{SparkConf, SparkException, TaskContext}
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
Expand Down Expand Up @@ -770,6 +770,47 @@ trait MergeIntoMaterializeSourceTests
assert(stats.materializeSourceReason.isDefined)
stats.materializeSourceReason.get
}

test("merge with non-deterministic scalar subquery in source") {
{
val numRows = 1000

val chooseOneRow = udf { i: Long =>
val stageId = TaskContext.get().stageId()
stageId % numRows == i
}
assert(chooseOneRow.deterministic)

// Create a subquery that returns a different row every time it is executed.
val subqueryViewName = "subquery_view"
spark.range(start = 0, end = numRows, step = 1, numPartitions = 1)
.filter(chooseOneRow(col("id")))
.createTempView(subqueryViewName)

val targetTableName = "target_table"
val sourceTableName = "source_table"
withTable(targetTableName, sourceTableName) {
spark.range(numRows).select(col("id").as("key"), col("id").as("value"))
.write.mode("overwrite").format("delta").saveAsTable(targetTableName)
spark.sql(s"SELECT key, value + $numRows AS value FROM $targetTableName")
.write.mode("overwrite").format("delta").saveAsTable(sourceTableName)

spark.sql(
s"""MERGE INTO $targetTableName t
|USING (SELECT * FROM $sourceTableName WHERE key = (SELECT * FROM $subqueryViewName)) s
|ON t.key = s.key
|WHEN MATCHED THEN UPDATE SET *
|WHEN NOT MATCHED THEN INSERT *""".stripMargin
)

// No new rows should have been inserted, as all keys in the source are already present in
// the target. If the subquery is evaluated multiple times, however, then the source may
// return different rows when finding the touched files and when writing the modified rows,
// in which case an update may be incorrectly treated as an insert.
assert(spark.table(targetTableName).count() === numRows)
}
}
}
}

// MERGE + materialize
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,113 @@ abstract class MergeIntoSuiteBase
}
}

test("non-correlated scalar subquery in source query") {
withTable("source") {
Seq((1, 6, "a"), (0, 3, "b")).toDF("key1", "value", "others")
.createOrReplaceTempView("source")
append(Seq((2, 2), (1, 4)).toDF("key2", "value"))

executeMerge(
target = s"delta.`$tempPath` as trg",
source = "(SELECT * FROM source WHERE value = (SELECT max(value) FROM source)) src",
condition = "src.key1 = trg.key2",
update = "trg.key2 = 20 + key1, value = 20 + src.value",
insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)")

checkAnswer(readDeltaTable(tempPath),
Row(2, 2) :: // No change
Row(21, 26) :: // UPDATE
Nil)
}
}

test("correlated scalar subquery in source query") {
withTable("source") {
Seq((1, 6, "a"), (0, 3, "b")).toDF("key1", "value", "others")
.createOrReplaceTempView("source")
append(Seq((2, 2), (1, 4)).toDF("key2", "value"))

executeMerge(
target = s"delta.`$tempPath` as trg",
source = "(SELECT * FROM source WHERE " +
s"value = (SELECT MAX(value) FROM delta.`$tempPath` WHERE key1 = key2)) src",
condition = "src.key1 = trg.key2",
update = "trg.key2 = 20 + key1, value = 20 + src.value",
insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)")

checkAnswer(readDeltaTable(tempPath),
Row(2, 2) :: // No change
Row(1, 4) :: // No change
Nil)
}
}

test("non-correlated exists subquery in source query") {
withTable("source") {
Seq((1, 6, "a"), (0, 3, "b")).toDF("key1", "value", "others")
.createOrReplaceTempView("source")
append(Seq((2, 2), (1, 4)).toDF("key2", "value"))

executeMerge(
target = s"delta.`$tempPath` as trg",
source = s"(SELECT * FROM source WHERE EXISTS (SELECT * FROM delta.`$tempPath`)) src",
condition = "src.key1 = trg.key2",
update = "trg.key2 = 20 + key1, value = 20 + src.value",
insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)")

checkAnswer(
readDeltaTable(tempPath),
Row(2, 2) :: // No change
Row(21, 26) :: // Update
Row(-10, 13) :: // Insert
Nil)
}
}

test("correlated exists subquery in source query") {
withTable("source") {
Seq((1, 6, "a"), (0, 3, "b")).toDF("key1", "value", "others")
.createOrReplaceTempView("source")
append(Seq((2, 2), (1, 4)).toDF("key2", "value"))

executeMerge(
target = s"delta.`$tempPath` as trg",
source = s"(SELECT * FROM source WHERE " +
s"EXISTS (SELECT * FROM delta.`$tempPath` WHERE key1 = key2)) src",
condition = "src.key1 = trg.key2",
update = "trg.key2 = 20 + key1, value = 20 + src.value",
insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)")

checkAnswer(
readDeltaTable(tempPath),
Row(2, 2) :: // No change
Row(21, 26) :: // Update
Nil)
}
}

test("in subquery in source query") {
withTable("source") {
Seq((1, 6, "a"), (0, 3, "b")).toDF("key1", "value", "others")
.createOrReplaceTempView("source")
append(Seq((2, 2), (1, 4)).toDF("key2", "value"))

executeMerge(
target = s"delta.`$tempPath` as trg",
source = s"(SELECT * FROM source WHERE " +
s"key1 IN (SELECT key2 FROM delta.`$tempPath`)) src",
condition = "src.key1 = trg.key2",
update = "trg.key2 = 20 + key1, value = 20 + src.value",
insert = "(trg.key2, value) VALUES (key1 - 10, src.value + 10)")

checkAnswer(
readDeltaTable(tempPath),
Row(2, 2) :: // No change
Row(21, 26) :: // Update
Nil)
}
}

testQuietly("Negative case - more than one source rows match the same target row") {
withTable("source") {
Seq((1, 1), (0, 3), (1, 5)).toDF("key1", "value").createOrReplaceTempView("source")
Expand Down

0 comments on commit 0a54980

Please sign in to comment.