Skip to content

Commit

Permalink
add benchmark ut for Spark TRowSet generation
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenliang123 committed Dec 4, 2023
1 parent 25eb53d commit 3619589
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 65 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.engine.spark.schema

import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

/**
 * Test fixtures shared by the RowSet suites: a generator producing one 17-column [[Row]]
 * per integer seed, and the schema (columns `a` to `q`) describing those rows.
 */
trait RowSetHelper {

  /**
   * Builds a row whose columns are all derived from `value`, in the same order as
   * [[schemaStructFields]].
   *
   * The only column that does not depend on `value` alone is the last one, which is read
   * from the system clock via `Instant.now()`.
   *
   * @param value seed of the row; it is also used as the repetition count of the string
   *              column and as the length of the binary and array columns
   * @return the generated row
   */
  protected def genRow(value: Int): Row = {
    // cycles through true / false / null depending on the remainder of the seed
    val remainder = value % 3
    val boolVal = if (remainder == 0) true else if (remainder == 1) false else null

    // "<value>.<value>" is the common textual source of the float, double and decimal columns
    val fractionText = s"$value.$value"
    val floatVal = java.lang.Float.valueOf(fractionText)
    val doubleVal = java.lang.Double.valueOf(fractionText)
    val stringVal = value.toString * value
    val decimalVal = new java.math.BigDecimal(fractionText)

    // two-digit day of month; `value % 30 + 1` stays within 1..30 for non-negative seeds
    val dayOfMonth = "%02d".format(value % 30 + 1)
    val dateVal = Date.valueOf(s"2018-11-$dayOfMonth")
    val timestampVal = Timestamp.valueOf(s"2018-11-17 13:33:33.$value")

    val binaryVal = Array.fill[Byte](value)(value.toByte)
    val arrVal = Array.fill(value)(doubleVal).toSeq
    val mapVal = Map(value -> doubleVal)
    val interval = new CalendarInterval(value, value, value)
    val localDate = LocalDate.of(2018, 11, 17)
    val instant = Instant.now()

    Row(
      boolVal,
      value.toByte,
      value.toShort,
      value,
      value.toLong,
      floatVal,
      doubleVal,
      stringVal,
      decimalVal,
      dateVal,
      timestampVal,
      binaryVal,
      arrVal,
      mapVal,
      interval,
      localDate,
      instant)
  }

  /**
   * One [[StructField]] per column of the rows built by [[genRow]]. The comment of each
   * field carries the name of the local value in [[genRow]] that fills the column.
   */
  protected val schemaStructFields: Seq[StructField] = {
    // parses the SQL type name and attaches the originating value name as the field comment
    def column(colName: String, typeName: String, comment: String): StructField =
      StructField(colName, CatalystSqlParser.parseDataType(typeName)).withComment(comment)

    Seq(
      column("a", "boolean", "boolVal"),
      column("b", "tinyint", "byteVal"),
      column("c", "smallint", "shortVal"),
      column("d", "int", "value"),
      column("e", "bigint", "longVal"),
      column("f", "float", "floatVal"),
      column("g", "double", "doubleVal"),
      column("h", "string", "stringVal"),
      column("i", "decimal", "decimalVal"),
      column("j", "date", "dateVal"),
      column("k", "timestamp", "timestampVal"),
      column("l", "binary", "binaryVal"),
      column("m", "array<double>", "arrVal"),
      column("n", "map<int, double>", "mapVal"),
      column("o", "interval", "interval"),
      column("p", "date", "localDate"),
      column("q", "timestamp", "instant"))
  }

  /** Schema assembled from [[schemaStructFields]]. */
  protected val schema: StructType = StructType(schemaStructFields)
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.kyuubi.engine.spark.schema
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.{Instant, LocalDate}

import scala.collection.JavaConverters._

Expand All @@ -31,70 +30,7 @@ import org.apache.spark.unsafe.types.CalendarInterval

import org.apache.kyuubi.KyuubiFunSuite

class RowSetSuite extends KyuubiFunSuite {

def genRow(value: Int): Row = {
val boolVal = value % 3 match {
case 0 => true
case 1 => false
case _ => null
}
val byteVal = value.toByte
val shortVal = value.toShort
val longVal = value.toLong
val floatVal = java.lang.Float.valueOf(s"$value.$value")
val doubleVal = java.lang.Double.valueOf(s"$value.$value")
val stringVal = value.toString * value
val decimalVal = new java.math.BigDecimal(s"$value.$value")
val day = java.lang.String.format("%02d", java.lang.Integer.valueOf(value + 1))
val dateVal = Date.valueOf(s"2018-11-$day")
val timestampVal = Timestamp.valueOf(s"2018-11-17 13:33:33.$value")
val binaryVal = Array.fill[Byte](value)(value.toByte)
val arrVal = Array.fill(value)(doubleVal).toSeq
val mapVal = Map(value -> doubleVal)
val interval = new CalendarInterval(value, value, value)
val localDate = LocalDate.of(2018, 11, 17)
val instant = Instant.now()

Row(
boolVal,
byteVal,
shortVal,
value,
longVal,
floatVal,
doubleVal,
stringVal,
decimalVal,
dateVal,
timestampVal,
binaryVal,
arrVal,
mapVal,
interval,
localDate,
instant)
}

val schema: StructType = new StructType()
.add("a", "boolean")
.add("b", "tinyint")
.add("c", "smallint")
.add("d", "int")
.add("e", "bigint")
.add("f", "float")
.add("g", "double")
.add("h", "string")
.add("i", "decimal")
.add("j", "date")
.add("k", "timestamp")
.add("l", "binary")
.add("m", "array<double>")
.add("n", "map<int, double>")
.add("o", "interval")
.add("p", "date")
.add("q", "timestamp")

class RowSetSuite extends KyuubiFunSuite with RowSetHelper {
private val rows: Seq[Row] = (0 to 10).map(genRow) ++ Seq(Row.fromSeq(Seq.fill(17)(null)))

test("column based set") {
Expand Down Expand Up @@ -259,4 +195,5 @@ class RowSetSuite extends KyuubiFunSuite {
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.engine.spark.schema

import scala.math.BigDecimal.RoundingMode

import org.apache.commons.lang3.time.StopWatch
import org.apache.hive.service.rpc.thrift.TProtocolVersion
import org.apache.hive.service.rpc.thrift.TProtocolVersion._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

import org.apache.kyuubi.KyuubiFunSuite

/**
 * Benchmark measuring the time spent generating a TRowSet from Spark rows.
 *
 * {{{
 *   RUN_BENCHMARK=1 ./build/mvn clean test \
 *   -pl externals/kyuubi-spark-sql-engine -am \
 *   -Dtest=none -DwildcardSuites=org.apache.kyuubi.engine.spark.schema.TRowSetBenchmark
 * }}}
 *
 * Every test is guarded by `assume(runBenchmark)`, where `runBenchmark` is true only when
 * the `RUN_BENCHMARK` environment variable is present.
 */
class TRowSetBenchmark extends KyuubiFunSuite with RowSetHelper {
  private val runBenchmark = sys.env.get("RUN_BENCHMARK").isDefined

  private val rowCount = 3000
  private lazy val allRows = for (seed <- 0 until rowCount) yield genRow(seed)

  // divisor converting the stopwatch's nanoseconds into milliseconds
  private val nanosPerMilli = BigDecimal(1000000)

  // one test per protocol version; V5 is reported as row-based and V6 as column-based
  Seq(
    "row-based toTRowSet benchmark" -> HIVE_CLI_SERVICE_PROTOCOL_V5,
    "column-based toTRowSet benchmark" -> HIVE_CLI_SERVICE_PROTOCOL_V6).foreach {
    case (testName, version) =>
      test(testName) {
        assume(runBenchmark)
        tRowSetGenerationBenchmark(version)
      }
  }

  /**
   * Runs the benchmark for one protocol version and prints a result table: one line per
   * single-column data type, a line with the sum over all types, and a line for rows
   * carrying all columns at once.
   *
   * @param protocolVersion version handed to `RowSet.toTRowSet`; versions below
   *                        `HIVE_CLI_SERVICE_PROTOCOL_V6` are labelled "row-based",
   *                        all others "column-based"
   */
  private def tRowSetGenerationBenchmark(protocolVersion: TProtocolVersion): Unit = {
    val isColumnBased = protocolVersion.getValue >= HIVE_CLI_SERVICE_PROTOCOL_V6.getValue
    val rowSetType = if (isColumnBased) "column-based" else "row-based"
    // scalastyle:off
    println(
      s"Benchmark result for $rowSetType RowSet.toTRowSet on $protocolVersion with $rowCount rows.")
    // scalastyle:on
    printResultLine("Type(s)", "Time Cost", "Rows/ms")

    // benchmark rows narrowed down to a single column, one data type at a time
    val perTypeCosts =
      for ((field, idx) <- schemaStructFields.zipWithIndex) yield {
        val singleColumnRows = allRows.map(row => Row(row.get(idx)))
        val singleColumnSchema = StructType(Seq(field))
        val label = field.getComment().getOrElse(field.dataType.typeName)
        benchmarkToTRowSet(label, singleColumnRows, singleColumnSchema, protocolVersion)
      }
    val totalMs = perTypeCosts.sum
    val totalRowsPerMs = rowsPerMs(rowCount, totalMs)
    // scalastyle:off
    println()
    // scalastyle:on
    printResultLine("sum(all types)", s"$totalMs ms", s"$totalRowsPerMs rows/ms")

    // benchmark the complete rows, i.e. all columns of all data types together
    benchmarkToTRowSet("with all types", allRows, schema, protocolVersion)

    // scalastyle:off
    println()
    println()
    // scalastyle:on
  }

  /**
   * Times a single `RowSet.toTRowSet` call and prints one result line for it.
   *
   * @param clue            label printed in the first column of the result line
   * @param rows            rows to convert
   * @param schema          schema describing `rows`
   * @param protocolVersion protocol version passed to `RowSet.toTRowSet`
   * @return the elapsed time in milliseconds, rounded half-up to 3 decimal places
   */
  private def benchmarkToTRowSet(
      clue: String,
      rows: Seq[Row],
      schema: StructType,
      protocolVersion: TProtocolVersion): BigDecimal = {
    val stopWatch = StopWatch.createStarted()
    RowSet.toTRowSet(rows, schema, protocolVersion)
    stopWatch.stop()

    val msTimeCost =
      (BigDecimal(stopWatch.getNanoTime) / nanosPerMilli).setScale(3, RoundingMode.HALF_UP)
    val throughput = rowsPerMs(rows.size, msTimeCost)
    printResultLine(clue, s"$msTimeCost ms", s"$throughput rows/ms")
    msTimeCost
  }

  /** Rows handled per millisecond, rounded half-up to 3 decimal places. */
  private def rowsPerMs(rows: Int, elapsedMs: BigDecimal): BigDecimal =
    (BigDecimal(rows) / elapsedMs).setScale(3, RoundingMode.HALF_UP)

  /** Prints one line of the result table as three right-aligned, 20-character wide columns. */
  private def printResultLine(first: String, second: String, third: String): Unit = {
    // scalastyle:off
    printf("%20s %20s %20s\n", first, second, third)
    // scalastyle:on
  }
}

0 comments on commit 3619589

Please sign in to comment.