Skip to content

Commit

Permalink
Add benchmark unit test for Spark TRowSet generation
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenliang123 committed Dec 3, 2023
1 parent 88c8cc3 commit 52c45c9
Showing 1 changed file with 93 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@ import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}

import scala.collection.JavaConverters._
import scala.math.BigDecimal.RoundingMode

import org.apache.commons.lang3.time.StopWatch
import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

Expand All @@ -46,7 +49,7 @@ class RowSetSuite extends KyuubiFunSuite {
val doubleVal = java.lang.Double.valueOf(s"$value.$value")
val stringVal = value.toString * value
val decimalVal = new java.math.BigDecimal(s"$value.$value")
val day = java.lang.String.format("%02d", java.lang.Integer.valueOf(value + 1))
val day = java.lang.String.format("%02d", java.lang.Integer.valueOf(value % 30 + 1))
val dateVal = Date.valueOf(s"2018-11-$day")
val timestampVal = Timestamp.valueOf(s"2018-11-17 13:33:33.$value")
val binaryVal = Array.fill[Byte](value)(value.toByte)
Expand Down Expand Up @@ -76,24 +79,32 @@ class RowSetSuite extends KyuubiFunSuite {
instant)
}

// Schema of the generated test rows: 17 columns named "a" through "q".
// Each type is given as a SQL type string and resolved by StructType.add(name, typeString).
val schema: StructType = Seq(
  "a" -> "boolean",
  "b" -> "tinyint",
  "c" -> "smallint",
  "d" -> "int",
  "e" -> "bigint",
  "f" -> "float",
  "g" -> "double",
  "h" -> "string",
  "i" -> "decimal",
  "j" -> "date",
  "k" -> "timestamp",
  "l" -> "binary",
  "m" -> "array<double>",
  "n" -> "map<int, double>",
  "o" -> "interval",
  "p" -> "date",
  "q" -> "timestamp").foldLeft(new StructType()) {
  case (acc, (colName, typeName)) => acc.add(colName, typeName)
}
/**
 * The individual columns of the test schema, in declaration order ("a" through "q").
 * Each SQL type string is parsed into a Spark DataType by CatalystSqlParser.
 * Kept as a plain sequence so single fields can be picked out by index.
 */
val schemaStructFields: Seq[StructField] = {
  val namedTypes = Seq(
    "a" -> "boolean",
    "b" -> "tinyint",
    "c" -> "smallint",
    "d" -> "int",
    "e" -> "bigint",
    "f" -> "float",
    "g" -> "double",
    "h" -> "string",
    "i" -> "decimal",
    "j" -> "date",
    "k" -> "timestamp",
    "l" -> "binary",
    "m" -> "array<double>",
    "n" -> "map<int, double>",
    "o" -> "interval",
    "p" -> "date",
    "q" -> "timestamp")
  for ((colName, typeName) <- namedTypes)
    yield StructField(colName, CatalystSqlParser.parseDataType(typeName))
}

/**
 * Full test schema assembled from [[schemaStructFields]], preserving field order.
 * Built directly through the StructType factory instead of mutating a `var`
 * with repeated `add` calls; the resulting fields are the same.
 */
val schema: StructType = StructType(schemaStructFields)

// Eleven generated rows (values 0 to 10) followed by one row that is null in every column.
// The null row's width is taken from the schema so it stays in sync if columns change.
private val rows: Seq[Row] =
  (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(schema.length)(null)))

Expand Down Expand Up @@ -259,4 +270,67 @@ class RowSetSuite extends KyuubiFunSuite {
}
}
}

test("to row set benchmark") {
  // Number of generated rows fed into every measured conversion.
  val rowCount = 10000
  val allRows = (0 until rowCount).map(genRow)
  // One layout shared by the table header and every result line.
  val lineFormat = "%20s %20s %20s\n"

  /**
   * Times one RowSet.toTRowSet call and prints the elapsed milliseconds and throughput.
   * This is a single-shot measurement with no warm-up.
   *
   * @param clue label printed in the first column
   * @param rows rows to convert
   * @param schema schema describing `rows`
   * @param protocolVersion Thrift protocol version passed to RowSet.toTRowSet
   */
  def benchmarkToTRowSet(
      clue: String,
      rows: Seq[Row],
      schema: StructType,
      protocolVersion: TProtocolVersion): Unit = {
    val sw = StopWatch.createStarted()
    RowSet.toTRowSet(rows, schema, protocolVersion)
    sw.stop()
    val msTimeCost: BigDecimal = (BigDecimal(sw.getNanoTime) / BigDecimal(1000000))
      .setScale(3, RoundingMode.HALF_UP)
    // A run that rounds to 0.000 ms would make the division throw ArithmeticException,
    // so throughput is reported as unavailable in that case.
    val throughput: String =
      if (msTimeCost.signum == 0) {
        "N/A"
      } else {
        val rowsPerMilliSecond: BigDecimal = (BigDecimal(rows.size) / msTimeCost)
          .setScale(3, RoundingMode.HALF_UP)
        s"$rowsPerMilliSecond rows/ms"
      }
    // scalastyle:off
    printf(lineFormat, clue, s"$msTimeCost ms", throughput)
    // scalastyle:on
  }

  /**
   * Benchmarks a single column: every row is projected down to the value at `index`
   * before the timer starts, and converted with a one-field schema.
   */
  def singleColumn(field: StructField, index: Int, protocolVersion: TProtocolVersion): Unit = {
    benchmarkToTRowSet(
      field.dataType.typeName,
      allRows.map(row => Row(row.get(index))),
      StructType(Seq(field)),
      protocolVersion)
  }

  Seq(
    TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V5,
    TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6)
    .foreach { protocolVersion =>
      // Label the run by protocol version: versions below V6 are reported as row-based.
      val mode =
        if (protocolVersion.getValue < TProtocolVersion.HIVE_CLI_SERVICE_PROTOCOL_V6.getValue) {
          "row-based"
        } else {
          "column-based"
        }
      // scalastyle:off
      println(s"Start testing $mode RowSet.toTRowSet over $protocolVersion with $rowCount rows.")
      printf(lineFormat, "Type(s)", "Time Cost", "Rows/ms")
      // scalastyle:on

      // First the full schema, then each column on its own.
      benchmarkToTRowSet("with all types", allRows, schema, protocolVersion)
      schemaStructFields.zipWithIndex.foreach { case (field, index) =>
        singleColumn(field, index, protocolVersion)
      }

      // scalastyle:off
      println()
      // scalastyle:on
    }
}
}

0 comments on commit 52c45c9

Please sign in to comment.