diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NondeterministicExpressionCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NondeterministicExpressionCollection.scala
new file mode 100644
index 0000000000000..19a209a307f94
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NondeterministicExpressionCollection.scala
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import java.util.{HashMap, LinkedHashMap}
+
+import org.apache.spark.sql.catalyst.expressions._
+
+/**
+ * Helpers used by [[PullOutNondeterministic]] to find nondeterministic sub-expressions and to
+ * replace them with references to named expressions.
+ */
+object NondeterministicExpressionCollection {
+  /**
+   * Maps each distinct nondeterministic node found in `expressions` to the named expression
+   * that stands in for it.
+   *
+   * Only expressions whose `deterministic` flag is false are searched. Within those, every
+   * [[Nondeterministic]] node and every nondeterministic [[UserDefinedExpression]] is collected.
+   * A collected node that is already a [[NamedExpression]] maps to itself; any other node maps
+   * to a new `Alias` named "_nondeterministic".
+   *
+   * @param expressions the expressions to search
+   * @return a map keyed by the collected node, iterating in insertion order so that callers
+   *         appending its values to a project list get columns in discovery order
+   *         (the concrete class is a `LinkedHashMap`)
+   */
+  def getNondeterministicToAttributes(
+      expressions: Seq[Expression]): HashMap[Expression, NamedExpression] = {
+    // LinkedHashMap is a subclass of HashMap, so the declared return type is unchanged.
+    val nondeterministicToAttributes = new LinkedHashMap[Expression, NamedExpression]
+    for (expression <- expressions if !expression.deterministic) {
+      val leafNondeterministic = expression.collect {
+        case n: Nondeterministic => n
+        case udf: UserDefinedExpression if !udf.deterministic => udf
+      }
+      leafNondeterministic.distinct.foreach { leaf =>
+        val named = leaf match {
+          case n: NamedExpression => n
+          case other => Alias(other, "_nondeterministic")()
+        }
+        // Key by the collected node, not the enclosing `expression`: callers look up one node
+        // at a time, so an `expression` key would replace the whole tree by this attribute.
+        nondeterministicToAttributes.put(leaf, named)
+      }
+    }
+    nondeterministicToAttributes
+  }
+
+  /**
+   * Returns the attribute that replaces `expression`, or `expression` itself when the map has
+   * no entry for it.
+   *
+   * @param expression the expression to look up
+   * @param nondeterministicToAttributes map built by [[getNondeterministicToAttributes]]
+   * @return `toAttribute` of the mapped named expression, or `expression` unchanged
+   */
+  def tryConvertNondeterministicToAttribute(
+      expression: Expression,
+      nondeterministicToAttributes: HashMap[Expression, NamedExpression]): Expression = {
+    // java.util.Map#get returns null for an absent key; Option(...) turns that into None.
+    Option(nondeterministicToAttributes.get(expression))
+      .map(_.toAttribute)
+      .getOrElse(expression)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala
index 3955142166831..e9a04f55f46e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministic.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import scala.jdk.CollectionConverters._
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -34,10 +36,14 @@ object PullOutNondeterministic extends Rule[LogicalPlan] {
     case f: Filter => f
 
     case a: Aggregate if a.groupingExpressions.exists(!_.deterministic) =>
-      val nondeterToAttr = getNondeterToAttr(a.groupingExpressions)
-      val newChild = Project(a.child.output ++ nondeterToAttr.values, a.child)
+      val nondeterToAttr =
+        NondeterministicExpressionCollection.getNondeterministicToAttributes(a.groupingExpressions)
+      val newChild = Project(a.child.output ++ nondeterToAttr.values.asScala.toSeq, a.child)
       a.transformExpressions { case e =>
-        nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
+        NondeterministicExpressionCollection.tryConvertNondeterministicToAttribute(
+          e,
+          nondeterToAttr
+        )
       }.copy(child = newChild)
 
     // Don't touch collect metrics. Top-level metrics are not supported (check analysis will fail)
@@ -51,27 +57,15 @@ object PullOutNondeterministic extends Rule[LogicalPlan] {
     // from LogicalPlan, currently we only do it for UnaryNode which has same output
     // schema with its child.
     case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
-      val nondeterToAttr = getNondeterToAttr(p.expressions)
+      val nondeterToAttr =
+        NondeterministicExpressionCollection.getNondeterministicToAttributes(p.expressions)
       val newPlan = p.transformExpressions { case e =>
-        nondeterToAttr.get(e).map(_.toAttribute).getOrElse(e)
+        NondeterministicExpressionCollection.tryConvertNondeterministicToAttribute(
+          e,
+          nondeterToAttr
+        )
       }
-      val newChild = Project(p.child.output ++ nondeterToAttr.values, p.child)
+      val newChild = Project(p.child.output ++ nondeterToAttr.values.asScala.toSeq, p.child)
       Project(p.output, newPlan.withNewChildren(newChild :: Nil))
   }
-
-  private def getNondeterToAttr(exprs: Seq[Expression]): Map[Expression, NamedExpression] = {
-    exprs.filterNot(_.deterministic).flatMap { expr =>
-      val leafNondeterministic = expr.collect {
-        case n: Nondeterministic => n
-        case udf: UserDefinedExpression if !udf.deterministic => udf
-      }
-      leafNondeterministic.distinct.map { e =>
-        val ne = e match {
-          case n: NamedExpression => n
-          case _ => Alias(e, "_nondeterministic")()
-        }
-        e -> ne
-      }
-    }.toMap
-  }
 }