diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 84e52282b632f..b4978fbe1f70a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -1267,6 +1267,9 @@ case class Shuffle(child: Expression, randomSeed: Option[Long] = None) extends U override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed)) + override def withShiftedSeed(shift: Long): Shuffle = + copy(randomSeed = randomSeed.map(_ + shift)) + override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e7d3701544c54..dcbca34b240b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -260,6 +260,8 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non override def withNewSeed(seed: Long): Uuid = Uuid(Some(seed)) + override def withShiftedSeed(shift: Long): Uuid = Uuid(randomSeed.map(_ + shift)) + override lazy val resolved: Boolean = randomSeed.isDefined override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index fa6eb2c111895..06cc6e55c8ec4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -76,6 +76,7 @@ trait ExpressionWithRandomSeed extends Expression { def seedExpression: Expression def withNewSeed(seed: Long): Expression + def withShiftedSeed(shift: Long): Expression } private[catalyst] object ExpressionWithRandomSeed { @@ -114,6 +115,9 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi override def withNewSeed(seed: Long): Rand = Rand(Literal(seed, LongType), hideSeed) + override def withShiftedSeed(shift: Long): Rand = + Rand(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed) + override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -165,6 +169,9 @@ case class Randn(child: Expression, hideSeed: Boolean = false) extends Nondeterm override def withNewSeed(seed: Long): Randn = Randn(Literal(seed, LongType), hideSeed) + override def withShiftedSeed(shift: Long): Randn = + Randn(Add(child, Literal(shift), evalMode = EvalMode.LEGACY), hideSeed) + override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -268,6 +275,9 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression, override def withNewSeed(newSeed: Long): Expression = Uniform(min, max, Literal(newSeed, LongType), hideSeed) + override def withShiftedSeed(shift: Long): Expression = + Uniform(min, max, Literal(seed + shift, LongType), hideSeed) + override def withNewChildrenInternal( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = Uniform(newFirst, newSecond, newThird, hideSeed) @@ -348,6 +358,10 @@ case class RandStr( override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType), hideSeed) + + override def withShiftedSeed(shift: Long): Expression = + RandStr(length, Literal(seed + shift, LongType), hideSeed) + override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = RandStr(newFirst, newSecond, hideSeed) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala index 33188db5d23b0..d44d3b0b6ef0b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnionLoopExec.scala @@ -22,7 +22,7 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, InterpretedMutableProjection, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpressionWithRandomSeed, InterpretedMutableProjection, Literal} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation.hasUnevaluableExpr import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, LocalRelation, LogicalPlan, OneRowRelation, Project, Union, UnionLoopRef} @@ -183,11 +183,24 @@ case class UnionLoopExec( // Main loop for obtaining the result of the recursive query. while (prevCount > 0 && !limitReached) { var prevPlan: LogicalPlan = null + + // If the recursive part contains non-deterministic expressions that depends on a seed, we + // need to create a new seed since the seed for this expression is set in the analysis, and + // we avoid re-triggering the analysis for every iterative step. + val recursionReseeded = if (currentLevel == 1 || recursion.deterministic) { + recursion + } else { + recursion.transformAllExpressionsWithSubqueries { + case e: ExpressionWithRandomSeed => + e.withShiftedSeed(currentLevel - 1) + } + } + // the current plan is created by substituting UnionLoopRef node with the project node of // the previous plan. // This way we support only UNION ALL case. Additional case should be added for UNION case. // One way of supporting UNION case can be seen at SPARK-24497 PR from Peter Toth. - val newRecursion = recursion.transformWithSubqueries { + val newRecursion = recursionReseeded.transformWithSubqueries { case r: UnionLoopRef if r.loopId == loopId => prevDF.queryExecution.optimizedPlan match { case l: LocalRelation => diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out index 4aff038838654..dc2b5a20fde51 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/cte-recursion.sql.out @@ -1629,3 +1629,75 @@ WithCTE +- Project [n#x] +- SubqueryAlias t1 +- CTERelationRef xxxx, true, [n#x], false, false + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(rand(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(rand(237685) * 5 + 1) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(UNIFORM(1, 6, 82374) AS INT) + UNION ALL + SELECT CAST(UNIFORM(1, 6, 237685) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(randn(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(randn(237685) * 5 + 1) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT randstr(10, 82374) + UNION ALL + SELECT randstr(10, 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT UUID(82374) + UNION ALL + SELECT UUID(237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT ARRAY(1,2,3,4,5) + UNION ALL + SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query analysis +[Analyzer test output redacted due to nondeterminism] diff --git a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql index fba8861083be4..8ef0c391a3fc5 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/cte-recursion.sql @@ -586,4 +586,58 @@ WITH RECURSIVE t1 AS ( SELECT 1 AS n UNION ALL SELECT n+1 FROM t2 WHERE n < 5) -SELECT * FROM t1; \ No newline at end of file +SELECT * FROM t1; + +-- Non-deterministic query with rand with seed +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(rand(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(rand(237685) * 5 + 1) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5; + +-- Non-deterministic query with uniform with seed +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(UNIFORM(1, 6, 82374) AS INT) + UNION ALL + SELECT CAST(UNIFORM(1, 6, 237685) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5; + +-- Non-deterministic query with randn with seed +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(randn(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(randn(237685) * 5 + 1) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5; + +-- Non-deterministic query with randstr +WITH RECURSIVE randoms(val) AS ( + SELECT randstr(10, 82374) + UNION ALL + SELECT randstr(10, 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5; + +-- Non-deterministic query with UUID +WITH RECURSIVE randoms(val) AS ( + SELECT UUID(82374) + UNION ALL + SELECT UUID(237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5; + +-- Non-deterministic query with shuffle +WITH RECURSIVE randoms(val) AS ( + SELECT ARRAY(1,2,3,4,5) + UNION ALL + SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out index 06f440a3f6335..d6939ab84b57c 100644 --- a/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/cte-recursion.sql.out @@ -1475,3 +1475,111 @@ struct 3 4 5 + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(rand(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(rand(237685) * 5 + 1) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct +-- !query output +1 +3 +4 +4 +5 + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(UNIFORM(1, 6, 82374) AS INT) + UNION ALL + SELECT CAST(UNIFORM(1, 6, 237685) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct +-- !query output +1 +3 +4 +4 +5 + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT CAST(floor(randn(82374) * 5 + 1) AS INT) + UNION ALL + SELECT CAST(floor(randn(237685) * 5 + 1) AS INT) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct +-- !query output +-2 +2 +2 +5 +6 + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT randstr(10, 82374) + UNION ALL + SELECT randstr(10, 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct +-- !query output +IpXzdTW03I +Zj7uI2Ex6e +dBlWnfo7rO +fmfDBMf60f +kFeBV7dQWi + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT UUID(82374) + UNION ALL + SELECT UUID(237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct +-- !query output +19974dca-21f6-47ef-b58c-73908ab52aa0 +4ea190e3-c088-4ddd-a545-fb431059ae3c +8b88900e-f862-468c-8d3b-828188116155 +be4f5346-1c7f-4697-8a2c-1343347872c5 +d0032efe-ae60-461b-8582-f6a7c649f238 + + +-- !query +WITH RECURSIVE randoms(val) AS ( + SELECT ARRAY(1,2,3,4,5) + UNION ALL + SELECT SHUFFLE(ARRAY(1,2,3,4,5), 237685) + FROM randoms +) +SELECT val FROM randoms LIMIT 5 +-- !query schema +struct> +-- !query output +[1,2,3,4,5] +[1,2,3,5,4] +[2,1,5,3,4] +[4,3,2,5,1] +[4,5,1,2,3]