
Commit

update
marin-ma committed Mar 20, 2024
1 parent eb5a24b commit 95dccec
Showing 4 changed files with 123 additions and 87 deletions.
@@ -18,13 +18,10 @@ package io.glutenproject.backendsapi.velox

import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.SparkPlanExecApi
import io.glutenproject.backendsapi.velox.SparkPlanExecApiImpl.supportsGenerate
import io.glutenproject.exception.GlutenNotSupportException
import io.glutenproject.execution._
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.extension.RewriteCollect.generatePreAliasName
import io.glutenproject.extension.columnar.PullOutPostProject.generatePostAliasName
import io.glutenproject.extension.columnar.TransformHints
import io.glutenproject.sql.shims.SparkShimLoader
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode, IfThenNode}
@@ -40,7 +37,7 @@ import org.apache.spark.sql.catalyst.{AggregateFunctionRewriteRule, FlushableHas
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, Cast, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, ElementAt, ExplodeBase, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, Inline, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim, Subtract, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, Cast, CreateNamedStruct, ElementAt, Expression, ExpressionInfo, Generator, GetArrayItem, GetMapValue, GetStructField, If, IsNaN, Literal, Murmur3Hash, NamedExpression, NaNvl, PosExplode, Round, SortOrder, StringSplit, StringTrim}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, HLLAdapter}
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans.JoinType
@@ -664,79 +661,12 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
}

override def genPreProjectForGenerate(generate: GenerateExec): SparkPlan = {
if (supportsGenerate(generate)) {
val newGeneratorChild = generate.generator.asInstanceOf[UnaryExpression].child match {
case attr: Attribute => attr
case expr @ (Literal(_, _) | CreateMap(_, _) | CreateArray(_, _)) =>
Alias(expr, generatePreAliasName)()
case other =>
throw new UnsupportedOperationException(
s"Fail to execute ${generate.generator.getClass.getSimpleName} " +
s"with child type ${other.getClass.getSimpleName}")
}
generate.copy(
generator =
generate.generator.withNewChildren(Seq(newGeneratorChild)).asInstanceOf[Generator],
child = ProjectExec(generate.child.output :+ newGeneratorChild, generate.child)
)
} else {
generate
}
PullOutGenerateProjectHelper.pullOutPreProject(generate)
}

override def genPostProjectForGenerate(generate: GenerateExec): SparkPlan = {
if (supportsGenerate(generate)) {
generate.generator match {
case PosExplode(_) =>
val originalOrdinal = generate.generatorOutput.head
val ordinal = {
val subtract = Subtract(Cast(originalOrdinal, IntegerType), Literal(1))
Alias(subtract, generatePostAliasName)(
originalOrdinal.exprId,
originalOrdinal.qualifier)
}
val newGenerate =
generate.copy(
generatorOutput = generate.generatorOutput.tail :+ originalOrdinal,
generator = GlutenGeneratorWrapper(generate.generator))
ProjectExec(
(generate.requiredChildOutput :+ ordinal) ++ generate.generatorOutput.tail,
newGenerate)
case Inline(_) =>
val unnestOutput = {
val struct = CreateStruct(generate.generatorOutput)
val alias = Alias(struct, generatePostAliasName)()
alias.toAttribute
}
val newGenerate = generate.copy(
generatorOutput = Seq(unnestOutput),
generator = GlutenGeneratorWrapper(generate.generator)
)
val newOutput = generate.generatorOutput.zipWithIndex.map {
case (attr, i) =>
val getStructField = GetStructField(unnestOutput, i, Some(attr.name))
Alias(getStructField, generatePostAliasName)(attr.exprId, attr.qualifier)
}
ProjectExec(generate.requiredChildOutput ++ newOutput, newGenerate)
case _ => generate
}
} else {
generate
}
PullOutGenerateProjectHelper.pullOutPostProject(generate)
}
}

object SparkPlanExecApiImpl {
def supportsGenerate(generate: GenerateExec): Boolean = {
if (generate.outer) {
false
} else {
generate.generator match {
case _: Inline | _: ExplodeBase =>
true
case _ =>
false
}
}
}
}
object SparkPlanExecApiImpl {}
@@ -17,20 +17,24 @@
package io.glutenproject.execution

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.execution.GenerateExecTransformer.supportsGenerate
import io.glutenproject.extension.ValidationResult
import io.glutenproject.metrics.{GenerateMetricsUpdater, MetricsUpdater}
import io.glutenproject.substrait.SubstraitContext
import io.glutenproject.substrait.expression.ExpressionNode
import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import io.glutenproject.utils.PullOutProjectHelper

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.{GenerateExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.IntegerType

import com.google.protobuf.StringValue

import scala.collection.JavaConverters._
import scala.collection.mutable

case class GenerateExecTransformer(
generator: Generator,
@@ -57,14 +61,11 @@ case class GenerateExecTransformer(
override protected def doGeneratorValidate(
generator: Generator,
outer: Boolean): ValidationResult = {
if (outer) {
return ValidationResult.notOk(s"Velox backend does not support outer")
}
generator match {
case _: JsonTuple =>
ValidationResult.notOk(s"Velox backend does not support this json_tuple")
case _ =>
ValidationResult.ok
if (!supportsGenerate(generator, outer)) {
ValidationResult.notOk(
s"Velox backend does not support this generator: ${generator.getClass.getSimpleName}, outer: $outer")
} else {
ValidationResult.ok
}
}

@@ -108,3 +109,99 @@
}
}
}

object GenerateExecTransformer {
def supportsGenerate(generator: Generator, outer: Boolean): Boolean = {
// TODO: support outer and remove this param.
if (outer) {
false
} else {
generator match {
case _: Inline | _: ExplodeBase =>
true
case _ =>
false
}
}
}
}

object PullOutGenerateProjectHelper extends PullOutProjectHelper {
def pullOutPreProject(generate: GenerateExec): SparkPlan = {
if (GenerateExecTransformer.supportsGenerate(generate.generator, generate.outer)) {
val newGeneratorChildren = generate.generator match {
case _: Inline | _: ExplodeBase =>
val expressionMap = new mutable.HashMap[Expression, NamedExpression]()
// The new child should be either the original Attribute
// or an Alias of another expression.
val generatorAttr = replaceExpressionWithAttribute(
generate.generator.asInstanceOf[UnaryExpression].child,
expressionMap,
replaceBoundReference = false)
val newGeneratorChild = if (expressionMap.isEmpty) {
// generator.child is Attribute
generatorAttr.asInstanceOf[Attribute]
} else {
// generator.child is another expression, e.g. Literal/CreateArray/CreateMap
expressionMap.values.head
}
Seq(newGeneratorChild)
case _ =>
// Unreachable.
throw new IllegalStateException(
s"Generator ${generate.generator.getClass.getSimpleName} is not supported.")
}
// Avoid using eliminateProjectList to create the project list
// because newGeneratorChild can be a duplicated Attribute in generate.child.output.
// The native side identifies the last field of the projection as the generator's input.
generate.copy(
generator =
generate.generator.withNewChildren(newGeneratorChildren).asInstanceOf[Generator],
child = ProjectExec(generate.child.output ++ newGeneratorChildren, generate.child)
)
} else {
generate
}
}

def pullOutPostProject(generate: GenerateExec): SparkPlan = {
if (GenerateExecTransformer.supportsGenerate(generate.generator, generate.outer)) {
generate.generator match {
case PosExplode(_) =>
val originalOrdinal = generate.generatorOutput.head
val ordinal = {
val subtract = Subtract(Cast(originalOrdinal, IntegerType), Literal(1))
Alias(subtract, generatePostAliasName)(
originalOrdinal.exprId,
originalOrdinal.qualifier)
}
val newGenerate =
generate.copy(
generatorOutput = generate.generatorOutput.tail :+ originalOrdinal,
generator = GlutenGeneratorWrapper(generate.generator))
ProjectExec(
(generate.requiredChildOutput :+ ordinal) ++ generate.generatorOutput.tail,
newGenerate)
case Inline(_) =>
val unnestOutput = {
val struct = CreateStruct(generate.generatorOutput)
val alias = Alias(struct, generatePostAliasName)()
alias.toAttribute
}
val newGenerate = generate.copy(
generatorOutput = Seq(unnestOutput),
generator = GlutenGeneratorWrapper(generate.generator)
)
val newOutput = generate.generatorOutput.zipWithIndex.map {
case (attr, i) =>
val getStructField = GetStructField(unnestOutput, i, Some(attr.name))
Alias(getStructField, generatePostAliasName)(attr.exprId, attr.qualifier)
}
ProjectExec(generate.requiredChildOutput ++ newOutput, newGenerate)
case _ => generate
}
} else {
generate
}
}
}
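
Note: in the PosExplode branch above, the post-project adjusts the ordinal column emitted by the wrapped generator by subtracting one, while reusing the original exprId and qualifier so downstream references still resolve. A minimal standalone sketch of that expression shape, assuming only Spark Catalyst on the classpath; the attribute name pos and the alias _post_0 are illustrative, not taken from this commit:

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Literal, Subtract}
import org.apache.spark.sql.types.IntegerType

object PosExplodeOrdinalSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical ordinal attribute produced by the generator (generatorOutput.head).
    val pos = AttributeReference("pos", IntegerType, nullable = false)()

    // Same shape as the post-project above: cast, subtract 1, and keep the original
    // exprId/qualifier so the adjusted ordinal transparently replaces the native one.
    val ordinal =
      Alias(Subtract(Cast(pos, IntegerType), Literal(1)), "_post_0")(pos.exprId, pos.qualifier)

    // Prints something like: (CAST(pos AS INT) - 1) AS _post_0
    println(ordinal.sql)
  }
}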
@@ -831,6 +831,14 @@ class TestOperator extends VeloxWholeStageTransformerSuite {
checkOperatorMatch[GenerateExecTransformer]
}
}

// Fallback for array(struct(...), null) literal.
runQueryAndCompare(s"""
|SELECT inline(array(
| named_struct('c1', 0, 'c2', 1),
| named_struct('c1', 2, 'c2', null),
| null));
|""".stripMargin)(_)
}

test("test array functions") {
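
Aside: the checkOperatorMatch[GenerateExecTransformer] pattern used earlier in this suite is how offload can be asserted for the supported generators. A hedged usage sketch, not part of this commit — tbl and its array column arr are made-up names for illustration:

runQueryAndCompare("SELECT posexplode(arr) FROM tbl") {
  checkOperatorMatch[GenerateExecTransformer]
}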
@@ -30,8 +30,8 @@ trait PullOutProjectHelper {

private val generatedNameIndex = new AtomicInteger(0)

def generatePreAliasName: String = s"_pre_${generatedNameIndex.getAndIncrement()}"
def generatePostAliasName: String = s"_post_${generatedNameIndex.getAndIncrement()}"
protected def generatePreAliasName: String = s"_pre_${generatedNameIndex.getAndIncrement()}"
protected def generatePostAliasName: String = s"_post_${generatedNameIndex.getAndIncrement()}"

/**
* The majority of Expressions only support Attribute and BoundReference when converting them into
@@ -57,12 +57,13 @@

protected def replaceExpressionWithAttribute(
expr: Expression,
projectExprsMap: mutable.HashMap[Expression, NamedExpression]): Expression =
projectExprsMap: mutable.HashMap[Expression, NamedExpression],
replaceBoundReference: Boolean = true): Expression =
expr match {
case alias: Alias =>
projectExprsMap.getOrElseUpdate(alias.child.canonicalized, alias).toAttribute
case attr: Attribute => attr
case e: BoundReference => e
case e: BoundReference if !replaceBoundReference => e
case other =>
projectExprsMap
.getOrElseUpdate(other.canonicalized, Alias(other, generatePreAliasName)())
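
Note on the new replaceBoundReference flag above: by default a BoundReference is now pulled out into the project map like any other non-attribute expression, and only a caller that passes replaceBoundReference = false (the Generate pre-project) keeps it in place. A throwaway usage sketch, assuming the Gluten and Spark Catalyst classes are on the classpath; DemoHelper exists only to reach the protected method and is not part of this commit:

import io.glutenproject.utils.PullOutProjectHelper

import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, NamedExpression}
import org.apache.spark.sql.types.IntegerType

import scala.collection.mutable

object DemoHelper extends PullOutProjectHelper {
  def run(): Unit = {
    val bound = BoundReference(0, IntegerType, nullable = false)

    // Default (replaceBoundReference = true): the bound reference is aliased and
    // replaced by the alias' attribute, so it lands in the pull-out project map.
    val pulledOut = new mutable.HashMap[Expression, NamedExpression]()
    replaceExpressionWithAttribute(bound, pulledOut)
    assert(pulledOut.size == 1)

    // replaceBoundReference = false: the bound reference is returned untouched,
    // which is what PullOutGenerateProjectHelper.pullOutPreProject relies on.
    val kept = new mutable.HashMap[Expression, NamedExpression]()
    val result = replaceExpressionWithAttribute(bound, kept, replaceBoundReference = false)
    assert(result == bound && kept.isEmpty)
  }
}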
