diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala index 4d547f771b49..daa195b68c58 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala @@ -17,12 +17,12 @@ package io.glutenproject.execution import io.glutenproject.backendsapi.BackendsApiManager -import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, LiteralTransformer} +import io.glutenproject.expression.{ConverterUtils, ExpressionConverter} import io.glutenproject.extension.ValidationResult import io.glutenproject.metrics.MetricsUpdater import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} import io.glutenproject.substrait.SubstraitContext -import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode} +import io.glutenproject.substrait.expression.ExpressionNode import io.glutenproject.substrait.extensions.ExtensionBuilder import io.glutenproject.substrait.rel.{RelBuilder, RelNode} @@ -32,9 +32,6 @@ import org.apache.spark.sql.execution._ import java.util.{ArrayList => JArrayList, List => JList} -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - case class ExpandExecTransformer( projections: Seq[Seq[Expression]], output: Seq[Attribute], @@ -66,110 +63,33 @@ case class ExpandExecTransformer( input: RelNode, validation: Boolean): RelNode = { val args = context.registeredFunction - def needsPreProjection(projections: Seq[Seq[Expression]]): Boolean = { - projections - .exists(set => set.exists(p => !p.isInstanceOf[Attribute] && !p.isInstanceOf[Literal])) - } - if (needsPreProjection(projections)) { - // if there is not literal and attribute expression in project sets, add a project op - // to calculate them before expand op. - val preExprs = ArrayBuffer.empty[Expression] - val selectionMaps = ArrayBuffer.empty[Seq[Int]] - var preExprIndex = 0 - for (i <- projections.indices) { - val selections = ArrayBuffer.empty[Int] - for (j <- projections(i).indices) { - val proj = projections(i)(j) - if (!proj.isInstanceOf[Literal]) { - val exprIdx = preExprs.indexWhere(expr => expr.semanticEquals(proj)) - if (exprIdx != -1) { - selections += exprIdx - } else { - preExprs += proj - selections += preExprIndex - preExprIndex = preExprIndex + 1 - } - } else { - selections += -1 - } - } - selectionMaps += selections - } - // make project - val preExprNodes = preExprs - .map( - ExpressionConverter - .replaceWithExpressionTransformer(_, originalInputAttributes) - .doTransform(args)) - .asJava - - val emitStartIndex = originalInputAttributes.size - val inputRel = if (!validation) { - RelBuilder.makeProjectRel(input, preExprNodes, context, operatorId, emitStartIndex) - } else { - // Use a extension node to send the input types through Substrait plan for a validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeProjectRel( - input, - preExprNodes, - extensionNode, - context, - operatorId, - emitStartIndex) - } - - // make expand - val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]() - for (i <- projections.indices) { + val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]() + projections.foreach { + projectSet => val projectExprNodes = new JArrayList[ExpressionNode]() - for (j <- projections(i).indices) { - val projectExprNode = projections(i)(j) match { - case l: Literal => - LiteralTransformer(l).doTransform(args) - case _ => - ExpressionBuilder.makeSelection(selectionMaps(i)(j)) - } - - projectExprNodes.add(projectExprNode) + projectSet.foreach { + project => + val projectExprNode = ExpressionConverter + .replaceWithExpressionTransformer(project, originalInputAttributes) + .doTransform(args) + projectExprNodes.add(projectExprNode) } projectSetExprNodes.add(projectExprNodes) - } - RelBuilder.makeExpandRel(inputRel, projectSetExprNodes, context, operatorId) + } + + if (!validation) { + RelBuilder.makeExpandRel(input, projectSetExprNodes, context, operatorId) } else { - val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]() - projections.foreach { - projectSet => - val projectExprNodes = new JArrayList[ExpressionNode]() - projectSet.foreach { - project => - val projectExprNode = ExpressionConverter - .replaceWithExpressionTransformer(project, originalInputAttributes) - .doTransform(args) - projectExprNodes.add(projectExprNode) - } - projectSetExprNodes.add(projectExprNodes) + // Use a extension node to send the input types through Substrait plan for a validation. + val inputTypeNodeList = new java.util.ArrayList[TypeNode]() + for (attr <- originalInputAttributes) { + inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) } - if (!validation) { - RelBuilder.makeExpandRel(input, projectSetExprNodes, context, operatorId) - } else { - // Use a extension node to send the input types through Substrait plan for a validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } - - val extensionNode = ExtensionBuilder.makeAdvancedExtension( - BackendsApiManager.getTransformerApiInstance.packPBMessage( - TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) - RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode, context, operatorId) - } + val extensionNode = ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance.packPBMessage( + TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) + RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode, context, operatorId) } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala index 5bf70597c84d..440f609de92d 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/PullOutPreProject.scala @@ -21,7 +21,7 @@ import io.glutenproject.utils.PullOutProjectHelper import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec} +import org.apache.spark.sql.execution.{ExpandExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression} import org.apache.spark.sql.execution.window.WindowExec @@ -74,6 +74,7 @@ object PullOutPreProject extends Rule[SparkPlan] with PullOutProjectHelper { } case _ => false }.isDefined) + case expand: ExpandExec => expand.projections.flatten.exists(isNotAttributeAndLiteral) case _ => false } } @@ -179,6 +180,15 @@ object PullOutPreProject extends Rule[SparkPlan] with PullOutProjectHelper { ProjectExec(window.output, newWindow) + case expand: ExpandExec if needsPreProject(expand) => + val expressionMap = new mutable.HashMap[Expression, NamedExpression]() + val newProjections = + expand.projections.map(_.map(replaceExpressionWithAttribute(_, expressionMap))) + expand.copy( + projections = newProjections, + child = ProjectExec( + eliminateProjectList(expand.child.outputSet, expressionMap.values.toSeq), + expand.child)) case _ => plan } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala index 892e5eeef4e0..8f3f01f9570b 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/columnar/RewriteSparkPlanRulesManager.scala @@ -53,6 +53,7 @@ class RewriteSparkPlanRulesManager(rewriteRules: Seq[Rule[SparkPlan]]) extends R case _: WindowExec => true case _: FilterExec => true case _: FileSourceScanExec => true + case _: ExpandExec => true case _ => false } }