diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala
index b98d36aab..b853dc482 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/profiling/ProfileClassWarehouse.scala
@@ -292,11 +292,11 @@ case class UnsupportedOpsProfileResult(appIndex: Int,
 case class AppInfoProfileResults(appIndex: Int, appName: String,
     appId: Option[String], sparkUser: String,
     startTime: Long, endTime: Option[Long], duration: Option[Long],
-    durationStr: String, sparkVersion: String,
+    durationStr: String, sparkRuntime: String, sparkVersion: String,
     pluginEnabled: Boolean) extends ProfileResult {
   override val outputHeaders = Seq("appIndex", "appName", "appId",
     "sparkUser", "startTime", "endTime", "duration", "durationStr",
-    "sparkVersion", "pluginEnabled")
+    "sparkRuntime", "sparkVersion", "pluginEnabled")
 
   def endTimeToStr: String = {
     endTime match {
@@ -315,13 +315,14 @@ case class AppInfoProfileResults(appIndex: Int, appName: String,
   override def convertToSeq: Seq[String] = {
     Seq(appIndex.toString, appName, appId.getOrElse(""),
       sparkUser, startTime.toString, endTimeToStr, durToStr,
-      durationStr, sparkVersion, pluginEnabled.toString)
+      durationStr, sparkRuntime, sparkVersion, pluginEnabled.toString)
   }
   override def convertToCSVSeq: Seq[String] = {
     Seq(appIndex.toString, StringUtils.reformatCSVString(appName),
       StringUtils.reformatCSVString(appId.getOrElse("")),
       StringUtils.reformatCSVString(sparkUser), startTime.toString, endTimeToStr,
       durToStr, StringUtils.reformatCSVString(durationStr),
-      StringUtils.reformatCSVString(sparkVersion), pluginEnabled.toString)
+      StringUtils.reformatCSVString(sparkRuntime), StringUtils.reformatCSVString(sparkVersion),
+      pluginEnabled.toString)
   }
 }
diff --git a/core/src/main/scala/com/nvidia/spark/rapids/tool/views/InformationView.scala b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/InformationView.scala
index 1da665b1f..66277551a 100644
--- a/core/src/main/scala/com/nvidia/spark/rapids/tool/views/InformationView.scala
+++ b/core/src/main/scala/com/nvidia/spark/rapids/tool/views/InformationView.scala
@@ -29,7 +29,7 @@ trait AppInformationViewTrait extends ViewableTrait[AppInfoProfileResults] {
     app.appMetaData.map { a =>
       AppInfoProfileResults(index, a.appName, a.appId, a.sparkUser,
         a.startTime, a.endTime, app.getAppDuration,
-        a.getDurationString, app.sparkVersion, app.gpuMode)
+        a.getDurationString, app.sparkRuntime.toString, app.sparkVersion, app.gpuMode)
     }.toSeq
   }
   override def sortView(rows: Seq[AppInfoProfileResults]): Seq[AppInfoProfileResults] = {
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
index e3313b832..da70afd91 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
@@ -37,7 +37,7 @@ import org.apache.spark.scheduler.{SparkListenerEvent, StageInfo}
 import org.apache.spark.sql.execution.SparkPlanInfo
 import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
 import org.apache.spark.sql.rapids.tool.store.{AccumManager, DataSourceRecord, SQLPlanModelManager, StageModel, StageModelManager, TaskModelManager}
-import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph, UTF8Source}
+import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, SparkRuntime, ToolsPlanGraph, UTF8Source}
 import org.apache.spark.util.Utils
 
 abstract class AppBase(
@@ -475,6 +475,7 @@ abstract class AppBase(
   protected def postCompletion(): Unit = {
     registerAttemptId()
     calculateAppDuration()
+    setSparkRuntime()
   }
 
   /**
@@ -485,6 +486,19 @@ abstract class AppBase(
     processEventsInternal()
     postCompletion()
   }
+
+  /**
+   * Sets the spark runtime based on the properties of the application.
+   */
+  private def setSparkRuntime(): Unit = {
+    sparkRuntime = if (isPhoton) {
+      SparkRuntime.PHOTON
+    } else if (gpuMode) {
+      SparkRuntime.SPARK_RAPIDS
+    } else {
+      SparkRuntime.SPARK
+    }
+  }
 }
 
 object AppBase {
diff --git a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/CacheablePropsHandler.scala b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/CacheablePropsHandler.scala
index bbdffc1d7..9c3f3abf7 100644
--- a/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/CacheablePropsHandler.scala
+++ b/core/src/main/scala/org/apache/spark/sql/rapids/tool/util/CacheablePropsHandler.scala
@@ -25,6 +25,15 @@ import org.apache.spark.scheduler.{SparkListenerEnvironmentUpdate, SparkListener
 import org.apache.spark.sql.rapids.tool.AppEventlogProcessException
 import org.apache.spark.util.Utils.REDACTION_REPLACEMENT_TEXT
 
+
+/**
+ * Enum to represent different spark runtimes.
+ */
+object SparkRuntime extends Enumeration {
+  type SparkRuntime = Value
+  val SPARK, SPARK_RAPIDS, PHOTON = Value
+}
+
 // Handles updating and caching Spark Properties for a Spark application.
 // Properties stored in this container can be accessed to make decision about certain analysis
 // that depends on the context of the Spark properties.
@@ -68,10 +77,12 @@ trait CacheablePropsHandler {
 
   // caches the spark-version from the eventlogs
   var sparkVersion: String = ""
+  // caches the spark runtime based on the application properties
+  var sparkRuntime: SparkRuntime.Value = SparkRuntime.SPARK
   var gpuMode = false
   // A flag whether hive is enabled or not. Note that we assume that the
   // property is global to the entire application once it is set. a.k.a, it cannot be disabled
-  // once it is was set to true.
+  // once it was set to true.
   var hiveEnabled = false
   // Indicates the ML eventlogType (i.e., Scala or pyspark). It is set only when MLOps are detected.
   // By default, it is empty.
diff --git a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
index 75cd86df5..2f2436fe4 100644
--- a/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
+++ b/core/src/test/scala/com/nvidia/spark/rapids/tool/profiling/ApplicationInfoSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.{SparkSession, TrampolineUtil}
 import org.apache.spark.sql.rapids.tool.profiling._
-import org.apache.spark.sql.rapids.tool.util.FSUtils
+import org.apache.spark.sql.rapids.tool.util.{FSUtils, SparkRuntime}
 
 class ApplicationInfoSuite extends FunSuite with Logging {
 
@@ -1115,4 +1115,18 @@ class ApplicationInfoSuite extends FunSuite with Logging {
       assert(actualResult == expectedResult)
     }
   }
+
+  val sparkRuntimeTestCases: Seq[(SparkRuntime.Value, String)] = Seq(
+    SparkRuntime.SPARK -> s"$qualLogDir/nds_q86_test",
+    SparkRuntime.SPARK_RAPIDS -> s"$logDir/nds_q66_gpu.zstd",
+    SparkRuntime.PHOTON -> s"$qualLogDir/nds_q88_photon_db_13_3.zstd"
+  )
+
+  sparkRuntimeTestCases.foreach { case (expectedSparkRuntime, eventLog) =>
+    test(s"test spark runtime property for ${expectedSparkRuntime.toString} eventlog") {
+      val apps = ToolTestUtils.processProfileApps(Array(eventLog), sparkSession)
+      assert(apps.size == 1)
+      assert(apps.head.sparkRuntime == expectedSparkRuntime)
+    }
+  }
 }
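
A minimal, self-contained Scala sketch of the runtime-detection precedence this patch introduces, for trying the logic outside the tools codebase. `RuntimeDetectionCheck` and its `detect` helper are hypothetical names; `isPhoton` and `gpuMode` stand in for the flags that `CacheablePropsHandler` derives from the event log properties.

object RuntimeDetectionCheck extends App {
  // Same enum shape as the SparkRuntime introduced in CacheablePropsHandler.scala.
  object SparkRuntime extends Enumeration {
    type SparkRuntime = Value
    val SPARK, SPARK_RAPIDS, PHOTON = Value
  }

  // Hypothetical helper mirroring the if/else-if ordering of AppBase.setSparkRuntime():
  // isPhoton is checked before gpuMode, so an app reporting both resolves to PHOTON.
  def detect(isPhoton: Boolean, gpuMode: Boolean): SparkRuntime.Value = {
    if (isPhoton) {
      SparkRuntime.PHOTON
    } else if (gpuMode) {
      SparkRuntime.SPARK_RAPIDS
    } else {
      SparkRuntime.SPARK
    }
  }

  assert(detect(isPhoton = true, gpuMode = true) == SparkRuntime.PHOTON)
  assert(detect(isPhoton = false, gpuMode = true) == SparkRuntime.SPARK_RAPIDS)
  assert(detect(isPhoton = false, gpuMode = false) == SparkRuntime.SPARK)
}

The ordering encodes a precedence: an application that reports both Photon and GPU mode is classified as PHOTON, and SPARK is the fallback when neither flag is set, matching the default of the new sparkRuntime field.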