From 4021d911ec5d4b229aa927a1cec32fc54453d651 Mon Sep 17 00:00:00 2001 From: Milan Cupac Date: Fri, 24 Jan 2025 21:13:41 +0800 Subject: [PATCH] [SPARK-50838][SQL] Performs additional checks inside recursive CTEs to throw an error if forbidden case is encountered ### What changes were proposed in this pull request? Performs additional checks inside recursive CTEs to throw an error if forbidden case is encountered: 1. Recursive term can contain one recursive reference only. 2. Recursive reference can't be used in some kinds of joins and aggregations. 3. Recursive references are not allowed in subqueries In addition, the name of `recursive` function inside `CTERelationDef` is rewritten to `hasRecursiveCTERelationRef` and adds `hasItsOwnUnionLoopRef` function as it is also needed to check if cteDef is recursive after substitution. A small bug in `CTESubstitution` is fixed which now enables substitution of self-references within subqueries as well (but not its resolution, as they are not allowed). ### Why are the changes needed? Support for the recursive CTE. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? will be tested in https://github.com/apache/spark/pull/49571 ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49518 from milanisvet/checkRecursion. Authored-by: Milan Cupac Signed-off-by: Wenchen Fan --- .../resources/error/error-conditions.json | 18 ++++ .../catalyst/analysis/CTESubstitution.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 8 ++ .../catalyst/analysis/ResolveWithCTE.scala | 82 +++++++++++++++++-- .../sql/catalyst/optimizer/InlineCTE.scala | 5 +- .../catalyst/plans/logical/cteOperators.scala | 8 +- 6 files changed, 112 insertions(+), 11 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 13fdbce0211fb..954fcdf3b4d93 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3135,6 +3135,24 @@ ], "sqlState" : "42836" }, + "INVALID_RECURSIVE_REFERENCE" : { + "message" : [ + "Invalid recursive reference found inside WITH RECURSIVE clause." + ], + "subClass" : { + "NUMBER" : { + "message" : [ + "Multiple self-references to one recursive CTE are not allowed." + ] + }, + "PLACE" : { + "message" : [ + "Recursive references cannot be used on the right side of left outer/semi/anti joins, on the left side of right outer joins, in full outer joins, in aggregates, and in subquery expressions." + ] + } + }, + "sqlState" : "42836" + }, "INVALID_REGEXP_REPLACE" : { "message" : [ "Could not perform regexp_replace for source = \"\", pattern = \"\", replacement = \"\" and position = ." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 0d2b0464251f2..a86cc4555ccb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -402,7 +402,7 @@ object CTESubstitution extends Rule[LogicalPlan] { other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { case e: SubqueryExpression => e.withNewPlan( - apply(substituteCTE(e.plan, alwaysInline, cteRelations, None))) + apply(substituteCTE(e.plan, alwaysInline, cteRelations, recursiveCTERelation))) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 397f04ac984a2..c7d5c355270f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -22,6 +22,7 @@ import org.apache.spark.{SparkException, SparkThrowable} import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.analysis.ResolveWithCTE.{checkForSelfReferenceInSubquery, checkIfSelfReferenceIsPlacedCorrectly} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, ListAgg, Median, PercentileCont, PercentileDisc} @@ -274,10 +275,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB checkTrailingCommaInSelect(proj) case agg: Aggregate => checkTrailingCommaInSelect(agg) + case unionLoop: UnionLoop => + // Recursive CTEs have already substituted Union to UnionLoop at this stage. + // Here we perform additional checks for them. + checkIfSelfReferenceIsPlacedCorrectly(unionLoop, unionLoop.id) case _ => } + // Check if there is any self-reference within subqueries + checkForSelfReferenceInSubquery(plan) + // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala index 3ad88514e17c7..454fcdbd38399 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveWithCTE.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.SubqueryExpression +import org.apache.spark.sql.catalyst.plans.{Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{CTE, PLAN_EXPRESSION} @@ -49,17 +50,18 @@ object ResolveWithCTE extends Rule[LogicalPlan] { plan.resolveOperatorsDownWithPruning(_.containsAllPatterns(CTE)) { case withCTE @ WithCTE(_, cteDefs) => val newCTEDefs = cteDefs.map { - // `cteDef.recursive` means "presence of a recursive CTERelationRef under cteDef". The - // side effect of node substitution below is that after CTERelationRef substitution - // its cteDef is no more considered `recursive`. This code path is common for `cteDef` - // that were non-recursive from the get go, as well as those that are no more recursive - // due to node substitution. - case cteDef if !cteDef.recursive => + // cteDef in the first case is either recursive and all the recursive CTERelationRefs + // are already substituted to UnionLoopRef in the previous pass, or it is not recursive + // at all. In both cases we need to put it in the map in case it is resolved. + // Second case is performing the substitution of recursive CTERelationRefs. + case cteDef if !cteDef.hasSelfReferenceAsCTERef => if (cteDef.resolved) { cteDefMap.put(cteDef.id, cteDef) } cteDef case cteDef => + // Multiple self-references are not allowed within one cteDef. + checkNumberOfSelfReferences(cteDef) cteDef.child match { // If it is a supported recursive CTE query pattern (4 so far), extract the anchor and // recursive plans from the Union and rewrite Union with UnionLoop. The recursive CTE @@ -183,4 +185,72 @@ object ResolveWithCTE extends Rule[LogicalPlan] { columnNames.map(UnresolvedSubqueryColumnAliases(_, ref)).getOrElse(ref) } } + + /** + * Checks if there is any self-reference within subqueries and throws an error + * if that is the case. + */ + def checkForSelfReferenceInSubquery(plan: LogicalPlan): Unit = { + plan.subqueriesAll.foreach { subquery => + subquery.foreach { + case r: CTERelationRef if r.recursive => + throw new AnalysisException( + errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE", + messageParameters = Map.empty) + case _ => + } + } + } + + /** + * Counts number of self-references in a recursive CTE definition and throws an error + * if that number is bigger than 1. + */ + private def checkNumberOfSelfReferences(cteDef: CTERelationDef): Unit = { + val numOfSelfRef = cteDef.collectWithSubqueries { + case ref: CTERelationRef if ref.cteId == cteDef.id => ref + }.length + if (numOfSelfRef > 1) { + cteDef.failAnalysis( + errorClass = "INVALID_RECURSIVE_REFERENCE.NUMBER", + messageParameters = Map.empty) + } + } + + /** + * Throws error if self-reference is placed in places which are not allowed: + * right side of left outer/semi/anti joins, left side of right outer joins, + * in full outer joins and in aggregates + */ + def checkIfSelfReferenceIsPlacedCorrectly( + plan: LogicalPlan, + cteId: Long, + allowRecursiveRef: Boolean = true): Unit = plan match { + case Join(left, right, Inner, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef) + case Join(left, right, LeftOuter, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false) + case Join(left, right, RightOuter, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = false) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef) + case Join(left, right, LeftSemi, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false) + case Join(left, right, LeftAnti, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false) + case Join(left, right, _, _, _) => + checkIfSelfReferenceIsPlacedCorrectly(left, cteId, allowRecursiveRef = false) + checkIfSelfReferenceIsPlacedCorrectly(right, cteId, allowRecursiveRef = false) + case Aggregate(_, _, child, _) => + checkIfSelfReferenceIsPlacedCorrectly(child, cteId, allowRecursiveRef = false) + case r: UnionLoopRef if !allowRecursiveRef && r.loopId == cteId => + throw new AnalysisException( + errorClass = "INVALID_RECURSIVE_REFERENCE.PLACE", + messageParameters = Map.empty) + case other => + other.children.foreach(checkIfSelfReferenceIsPlacedCorrectly(_, cteId, allowRecursiveRef)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index ad1a1a99b8257..62d99f7854891 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -61,7 +61,10 @@ case class InlineCTE( // 1) It is fine to inline a CTE if it references another CTE that is non-deterministic; // 2) Any `CTERelationRef` that contains `OuterReference` would have been inlined first. refCount == 1 || - cteDef.deterministic || + // Don't inline recursive CTEs if not necessary as recursion is very costly. + // The check if cteDef is recursive is performed by checking if it contains + // a UnionLoopRef with the same ID. + (cteDef.deterministic && !cteDef.hasSelfReferenceAsUnionLoopRef) || cteDef.child.exists(_.expressions.exists(_.isInstanceOf[OuterReference])) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala index ac4e2b77f793a..03236571eade7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/cteOperators.scala @@ -100,12 +100,14 @@ case class CTERelationDef( override def output: Seq[Attribute] = if (resolved) child.output else Nil - lazy val recursive: Boolean = child.exists{ - // If the reference is found inside the child, referencing to this CTE definition, - // and already marked as recursive, then this CTE definition is recursive. + lazy val hasSelfReferenceAsCTERef: Boolean = child.exists{ case CTERelationRef(this.id, _, _, _, _, true) => true case _ => false } + lazy val hasSelfReferenceAsUnionLoopRef: Boolean = child.exists{ + case UnionLoopRef(this.id, _, _) => true + case _ => false + } } object CTERelationDef {