[GLUTEN-7313][VL] Explicit Arrow transitions, part 2: new algorithm to find optimal transition (apache#7372)
zhztheplayer authored Sep 29, 2024
1 parent e29946d commit 255b0cc
Showing 21 changed files with 521 additions and 181 deletions.
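The commit's central idea, per its title and the FIXME it removes below: conventions (row and batch types) form a graph whose edges are the known transitions, and the planner searches that graph for the cheapest conversion chain instead of consulting a flat factory of pairwise rules. What follows is a minimal, self-contained Scala sketch of such a search; the names, edge set, and costs are illustrative stand-ins, not Gluten's API. With heavy row<->columnar conversions and cheap Arrow hops, the search recovers the row -> velox -> arrow-native -> arrow-java chain that the removed FIXME asks for.

import scala.collection.mutable

object TransitionSearchSketch {
  final case class Edge(from: String, to: String, cost: Int)

  // Hypothetical edge set: heavy row<->columnar conversions, cheap Arrow hops.
  val edges: Seq[Edge] = Seq(
    Edge("row", "velox", 10),
    Edge("velox", "row", 10),
    Edge("velox", "arrow-native", 1),
    Edge("arrow-native", "arrow-java", 1)
  )

  // Dijkstra over the transition graph; returns the cheapest path, if any.
  def cheapestPath(src: String, dst: String): Option[List[String]] = {
    val dist = mutable.Map(src -> 0).withDefaultValue(Int.MaxValue)
    val prev = mutable.Map.empty[String, String]
    val queue =
      mutable.PriorityQueue((0, src))(Ordering.by((t: (Int, String)) => -t._1))
    val settled = mutable.Set.empty[String]
    while (queue.nonEmpty) {
      val (d, v) = queue.dequeue()
      if (settled.add(v)) {             // visit each vertex once, cheapest first
        for (e <- edges if e.from == v) {
          val nd = d + e.cost
          if (nd < dist(e.to)) {        // found a cheaper route to e.to: relax
            dist(e.to) = nd
            prev(e.to) = v
            queue.enqueue((nd, e.to))
          }
        }
      }
    }
    if (dist(dst) == Int.MaxValue) None // no viable transition chain
    else {
      var path = List(dst)              // walk predecessors back to the source
      while (path.head != src) path ::= prev(path.head)
      Some(path)
    }
  }

  def main(args: Array[String]): Unit =
    // Prints: Some(List(row, velox, arrow-native, arrow-java))
    println(cheapestPath("row", "arrow-java"))
}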
CHBackend.scala
@@ -49,7 +49,7 @@ import scala.util.control.Breaks.{break, breakable}
 
 class CHBackend extends SubstraitBackend {
   override def name(): String = CHConf.BACKEND_NAME
-  override def batchType: Convention.BatchType = CHBatch
+  override def defaultBatchType: Convention.BatchType = CHBatch
   override def buildInfo(): Backend.BuildInfo =
     Backend.BuildInfo("ClickHouse", CH_BRANCH, CH_COMMIT, "UNKNOWN")
   override def iteratorApi(): IteratorApi = new CHIteratorApi
CHListenerApi.scala
@@ -18,6 +18,7 @@ package org.apache.gluten.backendsapi.clickhouse
 
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.ListenerApi
+import org.apache.gluten.columnarbatch.CHBatch
 import org.apache.gluten.execution.CHBroadcastBuildSideCache
 import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter}
 import org.apache.gluten.expression.UDFMappings
@@ -68,6 +69,8 @@ class CHListenerApi extends ListenerApi with Logging {
   override def onExecutorShutdown(): Unit = shutdown()
 
   private def initialize(conf: SparkConf, isDriver: Boolean): Unit = {
+    // Force batch type initializations.
+    CHBatch.getClass
     SparkDirectoryUtil.init(conf)
     val libPath = conf.get(GlutenConfig.GLUTEN_LIB_PATH, StringUtils.EMPTY)
     if (StringUtils.isBlank(libPath)) {
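A note on the `CHBatch.getClass` line added above: Scala objects initialize lazily, on first access, and (as the Convention.scala hunk later in this diff shows) a batch type now registers its graph vertices and edges in its constructor. Touching the object at startup forces that registration before any query is planned. A minimal sketch of the mechanism, with hypothetical names:

object EagerInitSketch {
  val registered = scala.collection.mutable.Set.empty[String]

  object MyBatch {
    // Constructor body runs on first reference to the object, not at JVM start.
    registered += "MyBatch"
  }

  def main(args: Array[String]): Unit = {
    println(registered) // empty set -- MyBatch is not initialized yet
    MyBatch.getClass    // forces initialization, as the listener does
    println(registered) // now contains "MyBatch"
  }
}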
CHBatch.scala
@@ -23,8 +23,9 @@ import org.apache.spark.sql.execution.{CHColumnarToRowExec, RowToCHNativeColumna
 /**
  * ClickHouse batch convention.
  *
- * [[fromRow]] and [[toRow]] need a [[TransitionDef]] instance. The scala allows an compact way to
- * implement trait using a lambda function.
+ * [[fromRow]] and [[toRow]] need a
+ * [[org.apache.gluten.extension.columnar.transition.TransitionDef]] instance. The scala allows an
+ * compact way to implement trait using a lambda function.
 *
 * Here the detail definition is given in [[CHBatch.fromRow]].
 * {{{
VeloxBackend.scala
@@ -25,8 +25,7 @@ import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution.WriteFilesExecTransformer
 import org.apache.gluten.expression.WindowFunctionsBuilder
 import org.apache.gluten.extension.ValidationResult
-import org.apache.gluten.extension.columnar.transition.Convention
-import org.apache.gluten.extension.columnar.transition.ConventionFunc.BatchOverride
+import org.apache.gluten.extension.columnar.transition.{Convention, ConventionFunc}
 import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat
 import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat.{DwrfReadFormat, OrcReadFormat, ParquetReadFormat}
@@ -37,7 +36,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CumeDist, DenseRank, De
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, ApproximatePercentile}
 import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer
+import org.apache.spark.sql.execution.{ColumnarCachedBatchSerializer, SparkPlan}
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand
 import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand}
@@ -51,14 +50,10 @@ import org.apache.hadoop.fs.Path
 import scala.util.control.Breaks.breakable
 
 class VeloxBackend extends SubstraitBackend {
-  import VeloxBackend._
   override def name(): String = VeloxBackend.BACKEND_NAME
-  override def batchType: Convention.BatchType = VeloxBatch
-  override def batchTypeFunc(): BatchOverride = {
-    case i: InMemoryTableScanExec
-        if i.supportsColumnar && i.relation.cacheBuilder.serializer
-          .isInstanceOf[ColumnarCachedBatchSerializer] =>
-      VeloxBatch
-  }
+  override def defaultBatchType: Convention.BatchType = VeloxBatch
+  override def convFuncOverride(): ConventionFunc.Override = new ConvFunc()
   override def buildInfo(): Backend.BuildInfo =
     Backend.BuildInfo("Velox", VELOX_BRANCH, VELOX_REVISION, VELOX_REVISION_TIME)
   override def iteratorApi(): IteratorApi = new VeloxIteratorApi
@@ -74,6 +69,15 @@ class VeloxBackend extends SubstraitBackend {
 object VeloxBackend {
   val BACKEND_NAME: String = "velox"
   val CONF_PREFIX: String = GlutenConfig.prefixOf(BACKEND_NAME)
+
+  private class ConvFunc() extends ConventionFunc.Override {
+    override def batchTypeOf: PartialFunction[SparkPlan, Convention.BatchType] = {
+      case i: InMemoryTableScanExec
+          if i.supportsColumnar && i.relation.cacheBuilder.serializer
+            .isInstanceOf[ColumnarCachedBatchSerializer] =>
+        VeloxBatch
+    }
+  }
 }
 
 object VeloxBackendSettings extends BackendSettingsApi {
VeloxListenerApi.scala
@@ -18,6 +18,8 @@ package org.apache.gluten.backendsapi.velox
 
 import org.apache.gluten.GlutenConfig
 import org.apache.gluten.backendsapi.ListenerApi
+import org.apache.gluten.columnarbatch.ArrowBatches.{ArrowJavaBatch, ArrowNativeBatch}
+import org.apache.gluten.columnarbatch.VeloxBatch
 import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter}
 import org.apache.gluten.expression.UDFMappings
 import org.apache.gluten.init.NativeBackendInitializer
@@ -119,6 +121,12 @@ class VeloxListenerApi extends ListenerApi with Logging {
   override def onExecutorShutdown(): Unit = shutdown()
 
   private def initialize(conf: SparkConf): Unit = {
+    // Force batch type initializations.
+    VeloxBatch.getClass
+    ArrowJavaBatch.getClass
+    ArrowNativeBatch.getClass
+
+    // Sets this configuration only once, since not undoable.
     if (conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE, defaultValue = false)) {
       val debugDir = conf.get(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR)
       JniWorkspace.enableDebug(debugDir)
ColumnarArrowEvalPythonExec.scala
@@ -16,8 +16,8 @@
  */
 package org.apache.spark.api.python
 
-import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxBatch}
 import org.apache.gluten.columnarbatch.ArrowBatches.ArrowJavaBatch
+import org.apache.gluten.columnarbatch.ColumnarBatches
 import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.extension.GlutenPlan
 import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq}
@@ -218,13 +218,8 @@ case class ColumnarArrowEvalPythonExec(
 
   override protected def batchType0(): Convention.BatchType = ArrowJavaBatch
 
-  // FIXME: Make this accepts ArrowJavaBatch as input. Before doing that, a weight-based
-  // shortest patch algorithm should be added into transition factory. So that the factory
-  // can find out row->velox->arrow-native->arrow-java as the possible viable transition.
-  // Otherwise with current solution, any input (even already in Arrow Java format) will be
-  // converted into Velox format then into Arrow Java format before entering python runner.
   override def requiredChildrenConventions(): Seq[ConventionReq] = List(
-    ConventionReq.of(ConventionReq.RowType.Any, ConventionReq.BatchType.Is(VeloxBatch)))
+    ConventionReq.of(ConventionReq.RowType.Any, ConventionReq.BatchType.Is(ArrowJavaBatch)))
 
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
@@ -348,17 +343,17 @@ case class ColumnarArrowEvalPythonExec(
     val inputBatchIter = contextAwareIterator.map {
       inputCb =>
        start_time = System.nanoTime()
-        val loaded = ColumnarBatches.load(ArrowBufferAllocators.contextInstance(), inputCb)
-        ColumnarBatches.retain(loaded)
+        ColumnarBatches.checkLoaded(inputCb)
+        ColumnarBatches.retain(inputCb)
         // 0. cache input for later merge
-        inputCbCache += loaded
-        numInputRows += loaded.numRows
+        inputCbCache += inputCb
+        numInputRows += inputCb.numRows
         // We only need to pass the referred cols data to python worker for evaluation.
         var colsForEval = new ArrayBuffer[ColumnVector]()
         for (i <- originalOffsets) {
-          colsForEval += loaded.column(i)
+          colsForEval += inputCb.column(i)
         }
-        new ColumnarBatch(colsForEval.toArray, loaded.numRows())
+        new ColumnarBatch(colsForEval.toArray, inputCb.numRows())
     }
 
     val outputColumnarBatchIterator =
VeloxTransitionSuite.scala
@@ -19,7 +19,6 @@ package org.apache.gluten.extension.columnar.transition
 import org.apache.gluten.backendsapi.velox.VeloxListenerApi
 import org.apache.gluten.columnarbatch.ArrowBatches.{ArrowJavaBatch, ArrowNativeBatch}
 import org.apache.gluten.columnarbatch.VeloxBatch
-import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.execution.{LoadArrowDataExec, OffloadArrowDataExec, RowToVeloxColumnarExec, VeloxColumnarToRowExec}
 import org.apache.gluten.extension.columnar.transition.Convention.BatchType.VanillaBatch
 import org.apache.gluten.test.MockVeloxBackend
@@ -64,11 +63,10 @@
 
   test("ArrowNative R2C - requires Arrow input") {
     val in = BatchUnary(ArrowNativeBatch, RowLeaf())
-    assertThrows[GlutenException] {
-      // No viable transitions.
-      // FIXME: Support this case.
-      Transitions.insertTransitions(in, outputsColumnar = false)
-    }
+    val out = Transitions.insertTransitions(in, outputsColumnar = false)
+    assert(
+      out == ColumnarToRowExec(
+        LoadArrowDataExec(BatchUnary(ArrowNativeBatch, RowToVeloxColumnarExec(RowLeaf())))))
   }
 
   test("ArrowNative-to-Velox C2C") {
@@ -92,11 +90,12 @@
 
   test("Vanilla-to-ArrowNative C2C") {
     val in = BatchUnary(ArrowNativeBatch, BatchLeaf(VanillaBatch))
-    assertThrows[GlutenException] {
-      // No viable transitions.
-      // FIXME: Support this case.
-      Transitions.insertTransitions(in, outputsColumnar = false)
-    }
+    val out = Transitions.insertTransitions(in, outputsColumnar = false)
+    assert(
+      out == ColumnarToRowExec(
+        LoadArrowDataExec(BatchUnary(
+          ArrowNativeBatch,
+          RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch)))))))
   }
 
   test("ArrowNative-to-Vanilla C2C") {
@@ -121,11 +120,10 @@
 
   test("ArrowJava R2C - requires Arrow input") {
     val in = BatchUnary(ArrowJavaBatch, RowLeaf())
-    assertThrows[GlutenException] {
-      // No viable transitions.
-      // FIXME: Support this case.
-      Transitions.insertTransitions(in, outputsColumnar = false)
-    }
+    val out = Transitions.insertTransitions(in, outputsColumnar = false)
+    assert(
+      out == ColumnarToRowExec(
+        BatchUnary(ArrowJavaBatch, LoadArrowDataExec(RowToVeloxColumnarExec(RowLeaf())))))
   }
 
   test("ArrowJava-to-Velox C2C") {
@@ -146,11 +144,12 @@
 
   test("Vanilla-to-ArrowJava C2C") {
     val in = BatchUnary(ArrowJavaBatch, BatchLeaf(VanillaBatch))
-    assertThrows[GlutenException] {
-      // No viable transitions.
-      // FIXME: Support this case.
-      Transitions.insertTransitions(in, outputsColumnar = false)
-    }
+    val out = Transitions.insertTransitions(in, outputsColumnar = false)
+    assert(
+      out == ColumnarToRowExec(
+        BatchUnary(
+          ArrowJavaBatch,
+          LoadArrowDataExec(RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch)))))))
   }
 
   test("ArrowJava-to-Vanilla C2C") {
@@ -195,8 +194,7 @@
     val in = BatchUnary(VanillaBatch, BatchLeaf(VeloxBatch))
     val out = Transitions.insertTransitions(in, outputsColumnar = false)
     assert(
-      out == ColumnarToRowExec(
-        BatchUnary(VanillaBatch, RowToColumnarExec(VeloxColumnarToRowExec(BatchLeaf(VeloxBatch))))))
+      out == ColumnarToRowExec(BatchUnary(VanillaBatch, LoadArrowDataExec(BatchLeaf(VeloxBatch)))))
   }
 
   override protected def beforeAll(): Unit = {
Backend.scala
@@ -39,15 +39,16 @@ trait Backend {
   def onExecutorStart(pc: PluginContext): Unit = {}
   def onExecutorShutdown(): Unit = {}
 
-  /** The columnar-batch type this backend is using. */
-  def batchType: Convention.BatchType
+  /** The columnar-batch type this backend is by default using. */
+  def defaultBatchType: Convention.BatchType
 
   /**
    * Overrides [[org.apache.gluten.extension.columnar.transition.ConventionFunc]] Gluten is using to
    * determine the convention (its row-based processing / columnar-batch processing support) of a
-   * plan with a user-defined function that accepts a plan then returns batch type it outputs.
+   * plan with a user-defined function that accepts a plan then returns convention type it outputs,
+   * and input conventions it requires.
    */
-  def batchTypeFunc(): ConventionFunc.BatchOverride = PartialFunction.empty
+  def convFuncOverride(): ConventionFunc.Override = ConventionFunc.Override.Empty
 
   /** Query planner rules. */
   def injectRules(injector: RuleInjector): Unit
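For context, a hedged sketch of how a backend might implement the two hooks above. The types here are simplified stand-ins for Convention.BatchType and ConventionFunc.Override (only SparkPlan is the real Spark class); VeloxBackend's ConvFunc earlier in this diff is the real-world counterpart:

import org.apache.spark.sql.execution.SparkPlan

object BackendHookSketch {
  trait BatchType
  case object MyBatch extends BatchType

  // Simplified mirror of ConventionFunc.Override: a partial function that can
  // claim a batch type for plans the generic detection cannot classify.
  trait ConvOverride {
    def batchTypeOf: PartialFunction[SparkPlan, BatchType] = PartialFunction.empty
  }
  object EmptyOverride extends ConvOverride

  // Simplified mirror of the Backend trait above.
  trait MiniBackend {
    def defaultBatchType: BatchType
    def convFuncOverride: ConvOverride = EmptyOverride
  }

  class MyBackend extends MiniBackend {
    override def defaultBatchType: BatchType = MyBatch
    override def convFuncOverride: ConvOverride = new ConvOverride {
      // Claim MyBatch for any columnar-capable plan; plans not matched here
      // fall through to the default convention detection.
      override def batchTypeOf: PartialFunction[SparkPlan, BatchType] = {
        case p if p.supportsColumnar => MyBatch
      }
    }
  }

  // Usage: callers would combine the override with a fallback, e.g.
  //   backend.convFuncOverride.batchTypeOf.applyOrElse(plan, defaultDetection)
}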
Convention.scala
@@ -61,31 +61,35 @@ object Convention {
     Impl(rowType, batchType)
   }
 
-  sealed trait RowType
+  sealed trait RowType extends TransitionGraph.Vertex with Serializable {
+    Transition.graph.addVertex(this)
+  }
 
   object RowType {
     // None indicates that the plan doesn't support row-based processing.
     final case object None extends RowType
     final case object VanillaRow extends RowType
   }
 
-  trait BatchType extends Serializable {
-    final def fromRow(transitionDef: TransitionDef): Unit = {
-      Transition.factory.update().defineFromRowTransition(this, transitionDef)
+  trait BatchType extends TransitionGraph.Vertex with Serializable {
+    Transition.graph.addVertex(this)
+
+    final protected def fromRow(transitionDef: TransitionDef): Unit = {
+      Transition.graph.addEdge(RowType.VanillaRow, this, transitionDef.create())
     }
 
-    final def toRow(transitionDef: TransitionDef): Unit = {
-      Transition.factory.update().defineToRowTransition(this, transitionDef)
+    final protected def toRow(transitionDef: TransitionDef): Unit = {
+      Transition.graph.addEdge(this, RowType.VanillaRow, transitionDef.create())
    }
 
-    final def fromBatch(from: BatchType, transitionDef: TransitionDef): Unit = {
+    final protected def fromBatch(from: BatchType, transitionDef: TransitionDef): Unit = {
       assert(from != this)
-      Transition.factory.update().defineBatchTransition(from, this, transitionDef)
+      Transition.graph.addEdge(from, this, transitionDef.create())
     }
 
-    final def toBatch(to: BatchType, transitionDef: TransitionDef): Unit = {
+    final protected def toBatch(to: BatchType, transitionDef: TransitionDef): Unit = {
       assert(to != this)
-      Transition.factory.update().defineBatchTransition(this, to, transitionDef)
+      Transition.graph.addEdge(this, to, transitionDef.create())
     }
   }
 
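To summarize the registration scheme above: every row/batch type is a vertex in a transition graph, and fromRow/toRow/fromBatch/toBatch add edges that the planner can later search (see the shortest-path sketch near the top). A minimal stand-in, with illustrative names in place of TransitionGraph, TransitionDef, and Transition.graph:

object GraphRegistrationSketch {
  sealed trait Vertex
  case object VanillaRow extends Vertex

  // Stand-in for Transition.graph: records edges as (from, to, transitionName).
  val graph = scala.collection.mutable.Buffer.empty[(Vertex, Vertex, String)]

  trait BatchType extends Vertex {
    final protected def fromRow(transition: String): Unit =
      graph += ((VanillaRow, this, transition))

    final protected def toRow(transition: String): Unit =
      graph += ((this, VanillaRow, transition))

    final protected def fromBatch(from: BatchType, transition: String): Unit = {
      assert(from != this) // self-edges are meaningless transitions
      graph += ((from, this, transition))
    }

    final protected def toBatch(to: BatchType, transition: String): Unit = {
      assert(to != this)
      graph += ((this, to, transition))
    }
  }

  // A batch type declares its conversions once, in its constructor body:
  case object MyArrowBatch extends BatchType {
    fromRow("RowToArrow") // edge VanillaRow -> MyArrowBatch
    toRow("ArrowToRow")   // edge MyArrowBatch -> VanillaRow
  }

  def main(args: Array[String]): Unit = {
    MyArrowBatch.getClass // force registration (cf. the listener changes above)
    graph.foreach(println)
  }
}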