Skip to content

Commit

Permalink
[GLUTEN-7313][VL] Explicit Arrow transitions, part 3: code cleanups (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Sep 29, 2024
1 parent 255b0cc commit d4934d7
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 174 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package org.apache.gluten.columnarbatch

import org.apache.gluten.extension.columnar.transition.Convention

import org.apache.spark.sql.execution.{CHColumnarToRowExec, RowToCHNativeColumnarExec, SparkPlan}
import org.apache.spark.sql.execution.{CHColumnarToRowExec, RowToCHNativeColumnarExec}

/**
* ClickHouse batch convention.
Expand All @@ -38,15 +38,6 @@ import org.apache.spark.sql.execution.{CHColumnarToRowExec, RowToCHNativeColumna
* }}}
*/
object CHBatch extends Convention.BatchType {
fromRow(
() =>
(plan: SparkPlan) => {
RowToCHNativeColumnarExec(plan)
})

toRow(
() =>
(plan: SparkPlan) => {
CHColumnarToRowExec(plan)
})
fromRow(RowToCHNativeColumnarExec.apply)
toRow(CHColumnarToRowExec.apply)
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,14 @@
*/
package org.apache.gluten.columnarbatch

import org.apache.gluten.execution.{LoadArrowDataExec, OffloadArrowDataExec, RowToVeloxColumnarExec, VeloxColumnarToRowExec}
import org.apache.gluten.extension.columnar.transition.{Convention, TransitionDef}

import org.apache.spark.sql.execution.SparkPlan
import org.apache.gluten.execution.{RowToVeloxColumnarExec, VeloxColumnarToRowExec}
import org.apache.gluten.extension.columnar.transition.{Convention, Transition}

object VeloxBatch extends Convention.BatchType {
fromRow(
() =>
(plan: SparkPlan) => {
RowToVeloxColumnarExec(plan)
})

toRow(
() =>
(plan: SparkPlan) => {
VeloxColumnarToRowExec(plan)
})

fromRow(RowToVeloxColumnarExec.apply)
toRow(VeloxColumnarToRowExec.apply)
// TODO: Add explicit transitions between Arrow native batch and Velox batch.
// See https://github.com/apache/incubator-gluten/issues/7313.

fromBatch(
ArrowBatches.ArrowJavaBatch,
() =>
(plan: SparkPlan) => {
OffloadArrowDataExec(plan)
})

toBatch(
ArrowBatches.ArrowJavaBatch,
() =>
(plan: SparkPlan) => {
LoadArrowDataExec(plan)
})

fromBatch(
ArrowBatches.ArrowNativeBatch,
() =>
(plan: SparkPlan) => {
LoadArrowDataExec(plan)
})

toBatch(ArrowBatches.ArrowNativeBatch, TransitionDef.empty)
fromBatch(ArrowBatches.ArrowNativeBatch, Transition.empty)
toBatch(ArrowBatches.ArrowNativeBatch, Transition.empty)
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,21 @@ class VeloxTransitionSuite extends SharedSparkSession {
test("ArrowNative C2R - outputs row") {
val in = BatchLeaf(ArrowNativeBatch)
val out = Transitions.insertTransitions(in, outputsColumnar = false)
assert(out == ColumnarToRowExec(LoadArrowDataExec(BatchLeaf(ArrowNativeBatch))))
assert(out == VeloxColumnarToRowExec(BatchLeaf(ArrowNativeBatch)))
}

test("ArrowNative C2R - requires row input") {
val in = RowUnary(BatchLeaf(ArrowNativeBatch))
val out = Transitions.insertTransitions(in, outputsColumnar = false)
assert(out == RowUnary(ColumnarToRowExec(LoadArrowDataExec(BatchLeaf(ArrowNativeBatch)))))
assert(out == RowUnary(VeloxColumnarToRowExec(BatchLeaf(ArrowNativeBatch))))
}

test("ArrowNative R2C - requires Arrow input") {
val in = BatchUnary(ArrowNativeBatch, RowLeaf())
val out = Transitions.insertTransitions(in, outputsColumnar = false)
assert(
out == ColumnarToRowExec(
LoadArrowDataExec(BatchUnary(ArrowNativeBatch, RowToVeloxColumnarExec(RowLeaf())))))
out == VeloxColumnarToRowExec(
BatchUnary(ArrowNativeBatch, RowToVeloxColumnarExec(RowLeaf()))))
}

test("ArrowNative-to-Velox C2C") {
Expand All @@ -75,27 +75,23 @@ class VeloxTransitionSuite extends SharedSparkSession {
// No explicit transition is inserted for ArrowNative-to-Velox at the moment.
// FIXME: Add explicit transitions.
// See https://github.com/apache/incubator-gluten/issues/7313.
assert(
out == VeloxColumnarToRowExec(
BatchUnary(VeloxBatch, LoadArrowDataExec(BatchLeaf(ArrowNativeBatch)))))
assert(out == VeloxColumnarToRowExec(BatchUnary(VeloxBatch, BatchLeaf(ArrowNativeBatch))))
}

test("Velox-to-ArrowNative C2C") {
val in = BatchUnary(ArrowNativeBatch, BatchLeaf(VeloxBatch))
val out = Transitions.insertTransitions(in, outputsColumnar = false)
assert(
out == ColumnarToRowExec(
LoadArrowDataExec(BatchUnary(ArrowNativeBatch, BatchLeaf(VeloxBatch)))))
assert(out == VeloxColumnarToRowExec(BatchUnary(ArrowNativeBatch, BatchLeaf(VeloxBatch))))
}

test("Vanilla-to-ArrowNative C2C") {
val in = BatchUnary(ArrowNativeBatch, BatchLeaf(VanillaBatch))
val out = Transitions.insertTransitions(in, outputsColumnar = false)
assert(
out == ColumnarToRowExec(
LoadArrowDataExec(BatchUnary(
out == VeloxColumnarToRowExec(
BatchUnary(
ArrowNativeBatch,
RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch)))))))
RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch))))))
}

test("ArrowNative-to-Vanilla C2C") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
package org.apache.gluten.columnarbatch

import org.apache.gluten.execution.{LoadArrowDataExec, OffloadArrowDataExec}
import org.apache.gluten.extension.columnar.transition.{Convention, TransitionDef}
import org.apache.gluten.extension.columnar.transition.{Convention, Transition}
import org.apache.gluten.extension.columnar.transition.Convention.BatchType.VanillaBatch

import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan}

object ArrowBatches {

/**
Expand All @@ -35,13 +33,7 @@ object ArrowBatches {
* implementations.
*/
object ArrowJavaBatch extends Convention.BatchType {
toRow(
() =>
(plan: SparkPlan) => {
ColumnarToRowExec(plan)
})

toBatch(VanillaBatch, TransitionDef.empty)
toBatch(VanillaBatch, Transition.empty)
}

/**
Expand All @@ -52,31 +44,7 @@ object ArrowBatches {
* [[ColumnarBatches]].
*/
object ArrowNativeBatch extends Convention.BatchType {
toRow(
() =>
(plan: SparkPlan) => {
ColumnarToRowExec(LoadArrowDataExec(plan))
})

toBatch(
VanillaBatch,
() =>
(plan: SparkPlan) => {
LoadArrowDataExec(plan)
})

fromBatch(
ArrowJavaBatch,
() =>
(plan: SparkPlan) => {
OffloadArrowDataExec(plan)
})

toBatch(
ArrowJavaBatch,
() =>
(plan: SparkPlan) => {
LoadArrowDataExec(plan)
})
fromBatch(ArrowJavaBatch, OffloadArrowDataExec.apply)
toBatch(ArrowJavaBatch, LoadArrowDataExec.apply)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,40 +74,31 @@ object Convention {
trait BatchType extends TransitionGraph.Vertex with Serializable {
Transition.graph.addVertex(this)

final protected def fromRow(transitionDef: TransitionDef): Unit = {
Transition.graph.addEdge(RowType.VanillaRow, this, transitionDef.create())
final protected def fromRow(transition: Transition): Unit = {
Transition.graph.addEdge(RowType.VanillaRow, this, transition)
}

final protected def toRow(transitionDef: TransitionDef): Unit = {
Transition.graph.addEdge(this, RowType.VanillaRow, transitionDef.create())
final protected def toRow(transition: Transition): Unit = {
Transition.graph.addEdge(this, RowType.VanillaRow, transition)
}

final protected def fromBatch(from: BatchType, transitionDef: TransitionDef): Unit = {
final protected def fromBatch(from: BatchType, transition: Transition): Unit = {
assert(from != this)
Transition.graph.addEdge(from, this, transitionDef.create())
Transition.graph.addEdge(from, this, transition)
}

final protected def toBatch(to: BatchType, transitionDef: TransitionDef): Unit = {
final protected def toBatch(to: BatchType, transition: Transition): Unit = {
assert(to != this)
Transition.graph.addEdge(this, to, transitionDef.create())
Transition.graph.addEdge(this, to, transition)
}
}

object BatchType {
// None indicates that the plan doesn't support batch-based processing.
final case object None extends BatchType
final case object VanillaBatch extends BatchType {
fromRow(
() =>
(plan: SparkPlan) => {
RowToColumnarExec(plan)
})

toRow(
() =>
(plan: SparkPlan) => {
ColumnarToRowExec(plan)
})
fromRow(RowToColumnarExec.apply)
toRow(ColumnarToRowExec.apply)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,6 @@ trait Transition {
protected def apply0(plan: SparkPlan): SparkPlan
}

trait TransitionDef {
def create(): Transition
}

object TransitionDef {
val empty: TransitionDef = () => Transition.empty
}

object Transition {
val empty: Transition = (plan: SparkPlan) => plan
private val abort: Transition = (_: SparkPlan) => throw new UnsupportedOperationException("Abort")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,59 +85,20 @@ class TransitionSuite extends SharedSparkSession {

object TransitionSuite extends TransitionSuiteBase {
object TypeA extends Convention.BatchType {
fromRow(
() =>
(plan: SparkPlan) => {
RowToBatch(this, plan)
})

toRow(
() =>
(plan: SparkPlan) => {
BatchToRow(this, plan)
})
fromRow(RowToBatch(this, _))
toRow(BatchToRow(this, _))
}

object TypeB extends Convention.BatchType {
fromRow(
() =>
(plan: SparkPlan) => {
RowToBatch(this, plan)
})

toRow(
() =>
(plan: SparkPlan) => {
BatchToRow(this, plan)
})
fromRow(RowToBatch(this, _))
toRow(BatchToRow(this, _))
}

object TypeC extends Convention.BatchType {
fromRow(
() =>
(plan: SparkPlan) => {
RowToBatch(this, plan)
})

toRow(
() =>
(plan: SparkPlan) => {
BatchToRow(this, plan)
})

fromBatch(
TypeA,
() =>
(plan: SparkPlan) => {
BatchToBatch(TypeA, this, plan)
})

toBatch(
TypeA,
() =>
(plan: SparkPlan) => {
BatchToBatch(this, TypeA, plan)
})
fromRow(RowToBatch(this, _))
toRow(BatchToRow(this, _))
fromBatch(TypeA, BatchToBatch(TypeA, this, _))
toBatch(TypeA, BatchToBatch(this, TypeA, _))
}

object TypeD extends Convention.BatchType {}
Expand Down

0 comments on commit d4934d7

Please sign in to comment.