Skip to content

Commit

Permalink
gluten-it exclude fixup
Browse files Browse the repository at this point in the history
gluten-it exclude
  • Loading branch information
zhztheplayer committed Mar 26, 2024
1 parent d671c1e commit 6e84e2f
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class TpcMixin {

@CommandLine.Option(names = {"--log-level"}, description = "Set log level: 0 for DEBUG, 1 for INFO, 2 for WARN", defaultValue = "2")
private int logLevel;

@CommandLine.Option(names = {"--error-on-memleak"}, description = "Fail the test when memory leak is detected by Spark's memory manager", defaultValue = "false")
private boolean errorOnMemLeak;

Expand Down Expand Up @@ -152,7 +153,7 @@ public Integer runActions(Action[] actions) {
return 0;
}

private <K,V> Map<K, V> mergeMapSafe(Map<K, V> conf, Map<? extends K, ? extends V> other) {
private <K, V> Map<K, V> mergeMapSafe(Map<K, V> conf, Map<? extends K, ? extends V> other) {
other.keySet().forEach(k -> {
if (conf.containsKey(k)) {
throw new IllegalArgumentException("Key already exists in conf: " + k);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ public class Parameterized implements Callable<Integer> {
@CommandLine.Option(names = {"--queries"}, description = "Set a comma-separated list of query IDs to run, run all queries if not specified. Example: --queries=q1,q6", split = ",")
private String[] queries = new String[0];

// Query IDs to leave out of the run; parsed from the comma-separated --excluded-queries option, empty when the option is not given.
@CommandLine.Option(names = {"--excluded-queries"}, description = "Set a comma-separated list of query IDs to exclude. Example: --excluded-queries=q1,q6", split = ",")
private String[] excludedQueries = new String[0];

@CommandLine.Option(names = {"--iterations"}, description = "How many iterations to run", defaultValue = "1")
private int iterations;

Expand Down Expand Up @@ -119,7 +122,7 @@ public Integer call() throws Exception {
)).collect(Collectors.toList())).asScala();

io.glutenproject.integration.tpc.action.Parameterized parameterized =
new io.glutenproject.integration.tpc.action.Parameterized(dataGenMixin.getScale(), this.queries, iterations, warmupIterations, parsedDims, metrics);
new io.glutenproject.integration.tpc.action.Parameterized(dataGenMixin.getScale(), this.queries, excludedQueries, iterations, warmupIterations, parsedDims, metrics);
return mixin.runActions(ArrayUtils.addAll(dataGenMixin.makeActions(), parameterized));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public class Queries implements Callable<Integer> {
@CommandLine.Option(names = {"--queries"}, description = "Set a comma-separated list of query IDs to run, run all queries if not specified. Example: --queries=q1,q6", split = ",")
private String[] queries = new String[0];

// Query IDs to leave out of the run; parsed from the comma-separated --excluded-queries option, empty when the option is not given.
@CommandLine.Option(names = {"--excluded-queries"}, description = "Set a comma-separated list of query IDs to exclude. Example: --excluded-queries=q1,q6", split = ",")
private String[] excludedQueries = new String[0];

@CommandLine.Option(names = {"--explain"}, description = "Output explain result for queries", defaultValue = "false")
private boolean explain;

Expand All @@ -47,7 +50,7 @@ public class Queries implements Callable<Integer> {
@Override
public Integer call() throws Exception {
io.glutenproject.integration.tpc.action.Queries queries =
new io.glutenproject.integration.tpc.action.Queries(dataGenMixin.getScale(), this.queries, explain, iterations, randomKillTasks);
new io.glutenproject.integration.tpc.action.Queries(dataGenMixin.getScale(), this.queries, this.excludedQueries, explain, iterations, randomKillTasks);
return mixin.runActions(ArrayUtils.addAll(dataGenMixin.makeActions(), queries));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ public class QueriesCompare implements Callable<Integer> {
@CommandLine.Option(names = {"--queries"}, description = "Set a comma-separated list of query IDs to run, run all queries if not specified. Example: --queries=q1,q6", split = ",")
private String[] queries = new String[0];

// Query IDs to leave out of the run; parsed from the comma-separated --excluded-queries option, empty when the option is not given.
@CommandLine.Option(names = {"--excluded-queries"}, description = "Set a comma-separated list of query IDs to exclude. Example: --excluded-queries=q1,q6", split = ",")
private String[] excludedQueries = new String[0];

@CommandLine.Option(names = {"--explain"}, description = "Output explain result for queries", defaultValue = "false")
private boolean explain;

Expand All @@ -44,7 +47,7 @@ public class QueriesCompare implements Callable<Integer> {
@Override
public Integer call() throws Exception {
io.glutenproject.integration.tpc.action.QueriesCompare queriesCompare =
new io.glutenproject.integration.tpc.action.QueriesCompare(dataGenMixin.getScale(), this.queries, explain, iterations);
new io.glutenproject.integration.tpc.action.QueriesCompare(dataGenMixin.getScale(), this.queries, this.excludedQueries, explain, iterations);
return mixin.runActions(ArrayUtils.addAll(dataGenMixin.makeActions(), queriesCompare));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,36 @@ abstract class TpcSuite(
private[tpc] def desc(): String

}

object TpcSuite {

  /** Adds query-selection helpers to any [[TpcSuite]] instance. */
  implicit class TpcSuiteImplicits(suite: TpcSuite) {

    /**
     * Resolves the list of query IDs an action should run.
     *
     * At most one of the two arguments may be non-empty:
     *  - if `queryIds` is non-empty, it is returned as-is;
     *  - if `excludedQueryIds` is non-empty, all query IDs of the suite except the excluded
     *    ones are returned, in the suite's own order;
     *  - if both are empty, all query IDs of the suite are returned.
     *
     * @param queryIds
     *   query IDs explicitly requested by the user, or an empty array
     * @param excludedQueryIds
     *   query IDs the user asked to skip, or an empty array
     * @return
     *   the query IDs to run
     * @throws IllegalArgumentException
     *   if both arguments are non-empty, or if any given ID is not among the suite's query IDs
     */
    def selectQueryIds(queryIds: Array[String], excludedQueryIds: Array[String]): Array[String] = {
      if (queryIds.nonEmpty && excludedQueryIds.nonEmpty) {
        throw new IllegalArgumentException(
          "Should not specify queries and excluded queries at the same time")
      }
      val all = suite.allQueryIds()
      val allSet = all.toSet

      // These IDs come from the command line, so reject bad ones with IllegalArgumentException
      // (same as the check above) rather than `assert`, which can be elided at compile time.
      // At most one of the two arrays is non-empty here, so checking both is harmless.
      queryIds.find(id => !allSet.contains(id)).foreach {
        id => throw new IllegalArgumentException("Invalid query ID: " + id)
      }
      excludedQueryIds.find(id => !allSet.contains(id)).foreach {
        id => throw new IllegalArgumentException("Invalid query ID to exclude: " + id)
      }

      if (queryIds.nonEmpty) {
        queryIds
      } else if (excludedQueryIds.nonEmpty) {
        val excludedSet = excludedQueryIds.toSet
        all.filterNot(excludedSet.contains)
      } else {
        all
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import scala.collection.mutable.ArrayBuffer
class Parameterized(
scale: Double,
queryIds: Array[String],
excludedQueryIds: Array[String],
iterations: Int,
warmupIterations: Int,
configDimensions: Seq[Dim],
Expand Down Expand Up @@ -104,19 +105,7 @@ class Parameterized(
sessionSwitcher.registerSession(coordinate.toString, conf)
}

val runQueryIds = queryIds match {
case Array() =>
allQueries
case _ =>
queryIds
}
val allQueriesSet = allQueries.toSet
runQueryIds.foreach {
queryId =>
if (!allQueriesSet.contains(queryId)) {
throw new IllegalArgumentException(s"Query ID doesn't exist: $queryId")
}
}
val runQueryIds = tpcSuite.selectQueryIds(queryIds, excludedQueryIds)

// warm up
(0 until warmupIterations).foreach {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,32 @@ package io.glutenproject.integration.tpc.action
import io.glutenproject.integration.stat.RamStat
import io.glutenproject.integration.tpc.{TpcRunner, TpcSuite}

import org.apache.spark.sql.SparkSessionSwitcher

import org.apache.commons.lang3.exception.ExceptionUtils

case class Queries(scale: Double, queryIds: Array[String], explain: Boolean, iterations: Int, randomKillTasks: Boolean)
case class Queries(
scale: Double,
queryIds: Array[String],
excludedQueryIds: Array[String],
explain: Boolean,
iterations: Int,
randomKillTasks: Boolean)
extends Action {

override def execute(tpcSuite: TpcSuite): Boolean = {
val runQueryIds = tpcSuite.selectQueryIds(queryIds, excludedQueryIds)
val runner: TpcRunner = new TpcRunner(tpcSuite.queryResource(), tpcSuite.dataWritePath(scale))
val allQueries = tpcSuite.allQueryIds()
val results = (0 until iterations).flatMap {
iteration =>
println(s"Running tests (iteration $iteration)...")
val runQueryIds = queryIds match {
case Array() =>
allQueries
case _ =>
queryIds
}
val allQueriesSet = allQueries.toSet
runQueryIds.map {
queryId =>
if (!allQueriesSet.contains(queryId)) {
throw new IllegalArgumentException(s"Query ID doesn't exist: $queryId")
}
Queries.runTpcQuery(runner, tpcSuite.sessionSwitcher, queryId, tpcSuite.desc(), explain, randomKillTasks)
Queries.runTpcQuery(
runner,
tpcSuite.sessionSwitcher,
queryId,
tpcSuite.desc(),
explain,
randomKillTasks)
}
}.toList

Expand Down Expand Up @@ -147,13 +147,24 @@ object Queries {
)))
}

private def runTpcQuery(runner: _root_.io.glutenproject.integration.tpc.TpcRunner, sessionSwitcher: _root_.org.apache.spark.sql.SparkSessionSwitcher, id: _root_.java.lang.String, desc: _root_.java.lang.String, explain: Boolean, randomKillTasks: Boolean) = {
private def runTpcQuery(
runner: _root_.io.glutenproject.integration.tpc.TpcRunner,
sessionSwitcher: _root_.org.apache.spark.sql.SparkSessionSwitcher,
id: _root_.java.lang.String,
desc: _root_.java.lang.String,
explain: Boolean,
randomKillTasks: Boolean) = {
println(s"Running query: $id...")
try {
val testDesc = "Gluten Spark %s %s".format(desc, id)
sessionSwitcher.useSession("test", testDesc)
runner.createTables(sessionSwitcher.spark())
val result = runner.runTpcQuery(sessionSwitcher.spark(), testDesc, id, explain = explain, randomKillTasks = randomKillTasks)
val result = runner.runTpcQuery(
sessionSwitcher.spark(),
testDesc,
id,
explain = explain,
randomKillTasks = randomKillTasks)
val resultRows = result.rows
println(
s"Successfully ran query $id. " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,22 @@ import org.apache.spark.sql.{SparkSessionSwitcher, TestUtils}

import org.apache.commons.lang3.exception.ExceptionUtils

case class QueriesCompare(scale: Double, queryIds: Array[String], explain: Boolean, iterations: Int)
case class QueriesCompare(
scale: Double,
queryIds: Array[String],
excludedQueryIds: Array[String],
explain: Boolean,
iterations: Int)
extends Action {

override def execute(tpcSuite: TpcSuite): Boolean = {
val runner: TpcRunner = new TpcRunner(tpcSuite.queryResource(), tpcSuite.dataWritePath(scale))
val allQueries = tpcSuite.allQueryIds()
val runQueryIds = tpcSuite.selectQueryIds(queryIds, excludedQueryIds)
val results = (0 until iterations).flatMap {
iteration =>
println(s"Running tests (iteration $iteration)...")
val runQueryIds = queryIds match {
case Array() =>
allQueries
case _ =>
queryIds
}
val allQueriesSet = allQueries.toSet
runQueryIds.map {
queryId =>
if (!allQueriesSet.contains(queryId)) {
throw new IllegalArgumentException(s"Query ID doesn't exist: $queryId")
}
QueriesCompare.runTpcQuery(
queryId,
explain,
Expand Down

0 comments on commit 6e84e2f

Please sign in to comment.