Skip to content

Commit

Permalink
[VL] Add some fixes following #8355 (#8373)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Dec 31, 2024
1 parent 43b69b6 commit afee97a
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.gluten.config.GlutenConfig

import org.apache.spark.SparkConf
import org.apache.spark.sql.{DataFrame, Row, TestUtils}
import org.apache.spark.sql.execution.FormattedMode
import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, FormattedMode}

import org.apache.commons.io.FileUtils

Expand Down Expand Up @@ -117,133 +117,133 @@ abstract class VeloxTPCHSuite extends VeloxTPCHTableSupport {
}

test("TPC-H q1") {
runTPCHQuery(1, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(1, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 1)
}
}

test("TPC-H q2") {
runTPCHQuery(2, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(2, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
_ => // due to tpc-h q2 will generate multiple plans, skip checking golden file for now
}
}

test("TPC-H q3") {
runTPCHQuery(3, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(3, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 3)
}
}

test("TPC-H q4") {
runTPCHQuery(4, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(4, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 4)
}
}

test("TPC-H q5") {
runTPCHQuery(5, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(5, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 5)
}
}

test("TPC-H q6") {
runTPCHQuery(6, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(6, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 6)
}
}

test("TPC-H q7") {
runTPCHQuery(7, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(7, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 7)
}
}

test("TPC-H q8") {
runTPCHQuery(8, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(8, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 8)
}
}

test("TPC-H q9") {
runTPCHQuery(9, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(9, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 9)
}
}

test("TPC-H q10") {
runTPCHQuery(10, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(10, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 10)
}
}

test("TPC-H q11") {
runTPCHQuery(11, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(11, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 11)
}
}

test("TPC-H q12") {
runTPCHQuery(12, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(12, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 12)
}
}

test("TPC-H q13") {
runTPCHQuery(13, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(13, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 13)
}
}

test("TPC-H q14") {
runTPCHQuery(14, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(14, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 14)
}
}

test("TPC-H q15") {
runTPCHQuery(15, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(15, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 15)
}
}

test("TPC-H q16") {
runTPCHQuery(16, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(16, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 16)
}
}

test("TPC-H q17") {
runTPCHQuery(17, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(17, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 17)
}
}

test("TPC-H q18") {
runTPCHQuery(18, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(18, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 18)
}
}

test("TPC-H q19") {
runTPCHQuery(19, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(19, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 19)
}
}

test("TPC-H q20") {
runTPCHQuery(20, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(20, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 20)
}
}

test("TPC-H q21") {
runTPCHQuery(21, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(21, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 21)
}
}

test("TPC-H q22") {
runTPCHQuery(22, tpchQueries, queriesResults, compareResult = false, noFallBack = false) {
runTPCHQuery(22, tpchQueries, queriesResults, compareResult = false, noFallBack = true) {
checkGoldenFile(_, 22)
}
}
Expand Down Expand Up @@ -304,6 +304,21 @@ class VeloxTPCHV1GlutenShuffleManagerSuite extends VeloxTPCHSuite {
.set("spark.sql.autoBroadcastJoinThreshold", "-1")
.set("spark.shuffle.manager", "org.apache.spark.shuffle.GlutenShuffleManager")
}

override protected def runQueryAndCompare(
sqlStr: String,
compareResult: Boolean,
noFallBack: Boolean,
cache: Boolean)(customCheck: DataFrame => Unit): DataFrame = {
assert(noFallBack)
super.runQueryAndCompare(sqlStr, compareResult, noFallBack, cache) {
df =>
assert(df.queryExecution.executedPlan.collect {
case p if p.isInstanceOf[ColumnarShuffleExchangeExec] => p
}.nonEmpty)
customCheck(df)
}
}
}

class VeloxTPCHV1BhjSuite extends VeloxTPCHSuite {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
Expand Down Expand Up @@ -93,22 +92,6 @@ case class ColumnarShuffleExchangeExec(
useSortBasedShuffle)
}

// 'shuffleDependency' is only needed when enable AQE.
// Columnar shuffle will use 'columnarShuffleDependency'
@transient
lazy val shuffleDependency: ShuffleDependency[Int, InternalRow, InternalRow] =
new ShuffleDependency[Int, InternalRow, InternalRow](
_rdd = new ColumnarShuffleExchangeExec.DummyPairRDDWithPartitions(
sparkContext,
inputColumnarRDD.getNumPartitions),
partitioner = columnarShuffleDependency.partitioner
) {

override val shuffleId: Int = columnarShuffleDependency.shuffleId

override val shuffleHandle: ShuffleHandle = columnarShuffleDependency.shuffleHandle
}

// super.stringArgs ++ Iterator(output.map(o => s"${o}#${o.dataType.simpleString}"))
val serializer: Serializer = BackendsApiManager.getSparkPlanExecApiInstance
.createColumnarBatchSerializer(schema, metrics, useSortBasedShuffle)
Expand All @@ -128,9 +111,9 @@ case class ColumnarShuffleExchangeExec(

override def nodeName: String = "ColumnarExchange"

override def numMappers: Int = shuffleDependency.rdd.getNumPartitions
override def numMappers: Int = inputColumnarRDD.getNumPartitions

override def numPartitions: Int = shuffleDependency.partitioner.numPartitions
override def numPartitions: Int = columnarShuffleDependency.partitioner.numPartitions

override def runtimeStatistics: Statistics = {
val dataSize = metrics("dataSize").value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class GlutenConfig(conf: SQLConf) extends Logging {
def isUseGlutenShuffleManager: Boolean =
conf
.getConfString("spark.shuffle.manager", "sort")
.equals("org.apache.spark.shuffle.sort.GlutenShuffleManager")
.equals("org.apache.spark.shuffle.GlutenShuffleManager")

// Whether to use ColumnarShuffleManager.
def isUseColumnarShuffleManager: Boolean =
Expand Down

0 comments on commit afee97a

Please sign in to comment.