Skip to content

Commit

Permalink
Pullout pre-project for ExpandExec
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiayi771 committed Mar 21, 2024
1 parent 0f5716f commit a29db9d
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
package io.glutenproject.execution

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, LiteralTransformer}
import io.glutenproject.expression.{ConverterUtils, ExpressionConverter}
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.MetricsUpdater
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
import io.glutenproject.substrait.expression.ExpressionNode
import io.glutenproject.substrait.extensions.ExtensionBuilder
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}

Expand All @@ -32,9 +32,6 @@ import org.apache.spark.sql.execution._

import java.util.{ArrayList => JArrayList, List => JList}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

case class ExpandExecTransformer(
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
Expand Down Expand Up @@ -66,110 +63,33 @@ case class ExpandExecTransformer(
input: RelNode,
validation: Boolean): RelNode = {
val args = context.registeredFunction
def needsPreProjection(projections: Seq[Seq[Expression]]): Boolean = {
projections
.exists(set => set.exists(p => !p.isInstanceOf[Attribute] && !p.isInstanceOf[Literal]))
}
if (needsPreProjection(projections)) {
// If the project sets contain expressions that are neither literals nor attributes,
// add a project op to calculate them before the expand op.
val preExprs = ArrayBuffer.empty[Expression]
val selectionMaps = ArrayBuffer.empty[Seq[Int]]
var preExprIndex = 0
for (i <- projections.indices) {
val selections = ArrayBuffer.empty[Int]
for (j <- projections(i).indices) {
val proj = projections(i)(j)
if (!proj.isInstanceOf[Literal]) {
val exprIdx = preExprs.indexWhere(expr => expr.semanticEquals(proj))
if (exprIdx != -1) {
selections += exprIdx
} else {
preExprs += proj
selections += preExprIndex
preExprIndex = preExprIndex + 1
}
} else {
selections += -1
}
}
selectionMaps += selections
}
// make project
val preExprNodes = preExprs
.map(
ExpressionConverter
.replaceWithExpressionTransformer(_, originalInputAttributes)
.doTransform(args))
.asJava

val emitStartIndex = originalInputAttributes.size
val inputRel = if (!validation) {
RelBuilder.makeProjectRel(input, preExprNodes, context, operatorId, emitStartIndex)
} else {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
for (attr <- originalInputAttributes) {
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
input,
preExprNodes,
extensionNode,
context,
operatorId,
emitStartIndex)
}

// make expand
val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
for (i <- projections.indices) {
val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
projections.foreach {
projectSet =>
val projectExprNodes = new JArrayList[ExpressionNode]()
for (j <- projections(i).indices) {
val projectExprNode = projections(i)(j) match {
case l: Literal =>
LiteralTransformer(l).doTransform(args)
case _ =>
ExpressionBuilder.makeSelection(selectionMaps(i)(j))
}

projectExprNodes.add(projectExprNode)
projectSet.foreach {
project =>
val projectExprNode = ExpressionConverter
.replaceWithExpressionTransformer(project, originalInputAttributes)
.doTransform(args)
projectExprNodes.add(projectExprNode)
}
projectSetExprNodes.add(projectExprNodes)
}
RelBuilder.makeExpandRel(inputRel, projectSetExprNodes, context, operatorId)
}

if (!validation) {
RelBuilder.makeExpandRel(input, projectSetExprNodes, context, operatorId)
} else {
val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]()
projections.foreach {
projectSet =>
val projectExprNodes = new JArrayList[ExpressionNode]()
projectSet.foreach {
project =>
val projectExprNode = ExpressionConverter
.replaceWithExpressionTransformer(project, originalInputAttributes)
.doTransform(args)
projectExprNodes.add(projectExprNode)
}
projectSetExprNodes.add(projectExprNodes)
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
for (attr <- originalInputAttributes) {
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}

if (!validation) {
RelBuilder.makeExpandRel(input, projectSetExprNodes, context, operatorId)
} else {
// Use an extension node to send the input types through the Substrait plan for validation.
val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
for (attr <- originalInputAttributes) {
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}

val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode, context, operatorId)
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode, context, operatorId)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import io.glutenproject.utils.PullOutProjectHelper
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Partial}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec}
import org.apache.spark.sql.execution.{ExpandExec, ProjectExec, SortExec, SparkPlan, TakeOrderedAndProjectExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, TypedAggregateExpression}
import org.apache.spark.sql.execution.window.WindowExec

Expand Down Expand Up @@ -74,6 +74,7 @@ object PullOutPreProject extends Rule[SparkPlan] with PullOutProjectHelper {
}
case _ => false
}.isDefined)
case expand: ExpandExec => expand.projections.flatten.exists(isNotAttributeAndLiteral)
case _ => false
}
}
Expand Down Expand Up @@ -179,6 +180,15 @@ object PullOutPreProject extends Rule[SparkPlan] with PullOutProjectHelper {

ProjectExec(window.output, newWindow)

case expand: ExpandExec if needsPreProject(expand) =>
val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
val newProjections =
expand.projections.map(_.map(replaceExpressionWithAttribute(_, expressionMap)))
expand.copy(
projections = newProjections,
child = ProjectExec(
eliminateProjectList(expand.child.outputSet, expressionMap.values.toSeq),
expand.child))
case _ => plan
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class RewriteSparkPlanRulesManager(rewriteRules: Seq[Rule[SparkPlan]]) extends R
case _: WindowExec => true
case _: FilterExec => true
case _: FileSourceScanExec => true
case _: ExpandExec => true
case _ => false
}
}
Expand Down

0 comments on commit a29db9d

Please sign in to comment.