Improve AQE support by capturing SQLPlan versions (#1354)
* Improve AQE support by capturing SQLPlan versions
* Define an optimized plan representation that does not track entire planInfos
* Capture V2 data sources

---------

Signed-off-by: Ahmed Hussein <[email protected]>
amahussein authored Sep 23, 2024
1 parent 92911a4 commit 14a4213
Showing 17 changed files with 608 additions and 83 deletions.
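
As a rough illustration of what the diffs below implement — capturing each AQE re-plan as a new SQL plan version instead of overwriting the previous one — here is a minimal Scala sketch. It is an assumption for orientation only: the real SQLPlanModelManager introduced by this commit exposes a richer API (applyToPlanInfo, getPlanById, getDataSourcesFromOrigPlans, ...) and, per the second bullet above, appears to keep only lightweight data for superseded versions rather than every full SparkPlanInfo.

```scala
import scala.collection.mutable

import org.apache.spark.sql.execution.SparkPlanInfo

// Illustrative only: a stripped-down versioned plan store. The actual
// SQLPlanModelManager added by this commit is richer (for example, it also
// tracks the data sources parsed out of superseded plan versions).
class VersionedPlanStore {
  case class PlanVersion(version: Int, planInfo: SparkPlanInfo, physicalPlanDesc: String)

  private val plans = mutable.SortedMap[Long, PlanVersion]()

  // SparkListenerSQLExecutionStart: first plan seen for a query, version 0.
  def addNewExecution(sqlId: Long, info: SparkPlanInfo, desc: String): Unit =
    plans(sqlId) = PlanVersion(0, info, desc)

  // SparkListenerSQLAdaptiveExecutionUpdate: AQE re-planned the query, so keep
  // the new plan as the "current" one and bump the version counter.
  def addAQE(sqlId: Long, info: SparkPlanInfo, desc: String): Unit = {
    val next = plans.get(sqlId).map(_.version + 1).getOrElse(0)
    plans(sqlId) = PlanVersion(next, info, desc)
  }

  // Read-side helpers mirroring calls that appear in the diffs below.
  def getPlanInfos: Map[Long, SparkPlanInfo] =
    plans.map { case (id, p) => id -> p.planInfo }.toMap
  def getPhysicalPlans: Seq[(Long, String)] =
    plans.toSeq.map { case (id, p) => id -> p.physicalPlanDesc }
  def remove(sqlId: Long): Option[PlanVersion] = plans.remove(sqlId)
}
```

With a store shaped like this, doSparkListenerSQLExecutionStart maps to addNewExecution and doSparkListenerSQLAdaptiveExecutionUpdate maps to addAQE, which is the wiring visible in the EventProcessorBase hunk further down.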
@@ -19,14 +19,15 @@ package com.nvidia.spark.rapids.tool.analysis
import scala.collection.mutable.{AbstractSet, ArrayBuffer, HashMap, LinkedHashSet}

import com.nvidia.spark.rapids.tool.planparser.SQLPlanParser
import com.nvidia.spark.rapids.tool.profiling.{AccumProfileResults, DataSourceCase, SQLAccumProfileResults, SQLMetricInfoCase, SQLStageInfoProfileResult, UnsupportedSQLPlan, WholeStageCodeGenResults}
import com.nvidia.spark.rapids.tool.profiling.{AccumProfileResults, SQLAccumProfileResults, SQLMetricInfoCase, SQLStageInfoProfileResult, UnsupportedSQLPlan, WholeStageCodeGenResults}
import com.nvidia.spark.rapids.tool.qualification.QualSQLPlanAnalyzer

import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.{SparkPlanGraph, SparkPlanGraphCluster, SparkPlanGraphNode}
import org.apache.spark.sql.rapids.tool.{AppBase, RDDCheckHelper, SqlPlanInfoGraphBuffer, SqlPlanInfoGraphEntry}
import org.apache.spark.sql.rapids.tool.profiling.ApplicationInfo
import org.apache.spark.sql.rapids.tool.qualification.QualificationAppInfo
import org.apache.spark.sql.rapids.tool.store.DataSourceRecord
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph

/**
@@ -99,7 +100,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
// as RDD or DS.
protected case class SQLPlanVisitorContext(
sqlPIGEntry: SqlPlanInfoGraphEntry,
sqlDataSources: ArrayBuffer[DataSourceCase] = ArrayBuffer[DataSourceCase](),
sqlDataSources: ArrayBuffer[DataSourceRecord] = ArrayBuffer[DataSourceRecord](),
potentialProblems: LinkedHashSet[String] = LinkedHashSet[String](),
var sqlIsDsOrRDD: Boolean = false)

@@ -253,7 +254,7 @@ class AppSQLPlanAnalyzer(app: AppBase, appIndex: Int) extends AppAnalysisBase(ap
val nodeIds = sqlPlanNodeIdToStageIds.filter { case (_, v) =>
v.contains(sModel.stageInfo.stageId)
}.keys.toSeq
val nodeNames = app.sqlPlans.get(j.sqlID.get).map { planInfo =>
val nodeNames = app.sqlManager.applyToPlanInfo(j.sqlID.get) { planInfo =>
val nodes = ToolsPlanGraph(planInfo).allNodes
val validNodes = nodes.filter { n =>
nodeIds.contains((j.sqlID.get, n.id))
@@ -16,14 +16,11 @@

package com.nvidia.spark.rapids.tool.analysis

import scala.collection.mutable.Map

import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
import org.apache.spark.sql.rapids.tool.SqlPlanInfoGraphEntry
import org.apache.spark.sql.rapids.tool.util.ToolsPlanGraph


// Class defines the SQLPlan context by implementations that walk through the SQLPlanInfo
class SQLPlanInfoContext(sqlPIGEntry: SqlPlanInfoGraphEntry) {
def getSQLPIGEntry: SqlPlanInfoGraphEntry = sqlPIGEntry
@@ -66,7 +63,7 @@ trait SparkSQLPlanInfoVisitor[R <: SQLPlanInfoContext] {
}

// Walks through all the SQLPlans in the given map
def walkPlans(plans: Map[Long, SparkPlanInfo]): Unit = {
def walkPlans(plans: collection.immutable.Map[Long, SparkPlanInfo]): Unit = {
for ((sqlId, planInfo) <- plans) {
val planCtxt = createPlanCtxt(sqlId, planInfo)
walkPlan(planCtxt)
@@ -46,6 +46,9 @@ object ReadParser extends Logging {
val METAFIELD_TAG_DATA_FILTERS = "DataFilters"
val METAFIELD_TAG_PUSHED_FILTERS = "PushedFilters"
val METAFIELD_TAG_PARTITION_FILTERS = "PartitionFilters"
val METAFIELD_TAG_READ_SCHEMA = "ReadSchema"
val METAFIELD_TAG_FORMAT = "Format"
val METAFIELD_TAG_LOCATION = "Location"

val UNKNOWN_METAFIELD: String = "unknown"
val DEFAULT_METAFIELD_MAP: Map[String, String] = collection.immutable.Map(
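
The three new METAFIELD_TAG_* constants presumably let callers refer to the V1 read-metadata keys by name instead of repeating string literals such as "Format" and "Location" (still visible in the AppBase hunk further down). A hedged usage sketch, with an invented metadata map:

```scala
// Sketch only: look up V1 scan metadata through the named constants rather
// than raw string literals. `meta` stands in for a node's metadata map.
val meta: Map[String, String] = Map(
  ReadParser.METAFIELD_TAG_FORMAT -> "Parquet",
  ReadParser.METAFIELD_TAG_LOCATION -> "/tmp/t",
  ReadParser.METAFIELD_TAG_READ_SCHEMA -> "a:int,b:string")

val format = ReadParser.extractTagFromV1ReadMeta(ReadParser.METAFIELD_TAG_FORMAT, meta)
val location = ReadParser.extractTagFromV1ReadMeta(ReadParser.METAFIELD_TAG_LOCATION, meta)
```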
@@ -129,7 +129,7 @@ object CollectInformation extends Logging {
val planFileWriter = new ToolTextFileWriter(s"$outputDir/${app.appId}",
"planDescriptions.log", "SQL Plan")
try {
for ((sqlID, planDesc) <- app.physicalPlanDescription.toSeq.sortBy(_._1)) {
for ((sqlID, planDesc) <- app.sqlManager.getPhysicalPlans) {
planFileWriter.write("\n=============================\n")
planFileWriter.write(s"Plan for SQL ID : $sqlID")
planFileWriter.write("\n=============================\n")
@@ -110,17 +110,14 @@ object GenerateDot {
list += row(1) -> row(2)
}

val sqlPlansMap = app.sqlPlans.map { case (sqlId, sparkPlanInfo) =>
sqlId -> ((sparkPlanInfo, app.physicalPlanDescription(sqlId)))
}
for ((sqlID, (planInfo, physicalPlan)) <- sqlPlansMap) {
for (sqlPlan <- app.sqlManager.sqlPlans.values) {
val dotFileWriter = new ToolTextFileWriter(outputDirectory,
s"query-$sqlID.dot", "Dot file")
s"query-${sqlPlan.id}.dot", "Dot file")
try {
val metrics = sqlIdToMaxMetric.getOrElse(sqlID, Seq.empty).toMap
val metrics = sqlIdToMaxMetric.getOrElse(sqlPlan.id, Seq.empty).toMap
GenerateDot.writeDotGraph(
QueryPlanWithMetrics(SparkPlanInfoWithStage(planInfo, accumIdToStageId), metrics),
physicalPlan, stageIdToStageMetrics, dotFileWriter, sqlID, app.appId)
QueryPlanWithMetrics(SparkPlanInfoWithStage(sqlPlan.planInfo, accumIdToStageId), metrics),
sqlPlan.physicalPlanDesc, stageIdToStageMetrics, dotFileWriter, sqlPlan.id, app.appId)
} finally {
dotFileWriter.close()
}
@@ -53,8 +53,8 @@ class SQLPlanClassifier(app: ApplicationInfo)

override def visitNode(sqlPlanCtxt: SQLPlanClassifierCtxt, node: ui.SparkPlanGraphNode): Unit = {
// Check if the node is a delta metadata operation
val isDeltaLog = DeltaLakeHelper.isDeltaOpNode(
sqlPlanCtxt.sqlPIGEntry, app.physicalPlanDescription(sqlPlanCtxt.getSQLPIGEntry.sqlID), node)
val isDeltaLog = DeltaLakeHelper.isDeltaOpNode(sqlPlanCtxt.sqlPIGEntry,
app.sqlManager.getPhysicalPlanById(sqlPlanCtxt.getSQLPIGEntry.sqlID).get, node)
if (isDeltaLog) {
// if it is a Delta operation, add it to the list of Delta operations nodes
sqlPlanCtxt.deltaOpsNode += node.id
@@ -134,27 +134,28 @@ case class RapidsJarProfileResult(appIndex: Int, jar: String) extends ProfileRe
}
}

case class DataSourceProfileResult(appIndex: Int, sqlID: Long, nodeId: Long,
case class DataSourceProfileResult(appIndex: Int, sqlID: Long, version: Int, nodeId: Long,
format: String, buffer_time: Long, scan_time: Long, data_size: Long,
decode_time: Long, location: String, pushedFilters: String, schema: String,
dataFilters: String, partitionFilters: String)
dataFilters: String, partitionFilters: String, fromFinalPlan: Boolean)
extends ProfileResult {
override val outputHeaders =
Seq("appIndex", "sqlID", "nodeId", "format", "buffer_time", "scan_time", "data_size",
"decode_time", "location", "pushedFilters", "schema", "data_filters", "partition_filters")
Seq("appIndex", "sqlID", "sql_plan_version", "nodeId", "format", "buffer_time", "scan_time",
"data_size", "decode_time", "location", "pushedFilters", "schema", "data_filters",
"partition_filters", "from_final_plan")

override def convertToSeq: Seq[String] = {
Seq(appIndex.toString, sqlID.toString, nodeId.toString, format, buffer_time.toString,
scan_time.toString, data_size.toString, decode_time.toString,
location, pushedFilters, schema, dataFilters, partitionFilters)
Seq(appIndex.toString, sqlID.toString, version.toString, nodeId.toString, format,
buffer_time.toString, scan_time.toString, data_size.toString, decode_time.toString,
location, pushedFilters, schema, dataFilters, partitionFilters, fromFinalPlan.toString)
}
override def convertToCSVSeq: Seq[String] = {
Seq(appIndex.toString, sqlID.toString, nodeId.toString, StringUtils.reformatCSVString(format),
buffer_time.toString, scan_time.toString, data_size.toString, decode_time.toString,
StringUtils.reformatCSVString(location), StringUtils.reformatCSVString(pushedFilters),
StringUtils.reformatCSVString(schema),
StringUtils.reformatCSVString(dataFilters),
StringUtils.reformatCSVString(partitionFilters))
Seq(appIndex.toString, sqlID.toString, version.toString, nodeId.toString,
StringUtils.reformatCSVString(format), buffer_time.toString, scan_time.toString,
data_size.toString, decode_time.toString, StringUtils.reformatCSVString(location),
StringUtils.reformatCSVString(pushedFilters), StringUtils.reformatCSVString(schema),
StringUtils.reformatCSVString(dataFilters), StringUtils.reformatCSVString(partitionFilters),
fromFinalPlan.toString)
}
}
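
To make the widened schema concrete, here is a hypothetical instantiation (all values invented) showing where the two added columns, sql_plan_version and from_final_plan, sit in the row; the data-source view further down also sorts on the new version column.

```scala
// Hypothetical values, purely to show the new row layout.
val row = DataSourceProfileResult(
  appIndex = 1, sqlID = 4L, version = 2, nodeId = 7L,
  format = "Parquet", buffer_time = 0L, scan_time = 1234L, data_size = 56789L,
  decode_time = 0L, location = "/tmp/t", pushedFilters = "[]",
  schema = "a:int,b:string", dataFilters = "[]", partitionFilters = "[]",
  fromFinalPlan = true)

// appIndex, sqlID, sql_plan_version, nodeId, format, ..., from_final_plan
println(row.outputHeaders.zip(row.convertToSeq).mkString(", "))
```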

@@ -364,16 +365,6 @@ case class DriverAccumCase(
case class UnsupportedSQLPlan(sqlID: Long, nodeID: Long, nodeName: String,
nodeDesc: String, reason: String)

case class DataSourceCase(
sqlID: Long,
nodeId: Long,
format: String,
location: String,
pushedFilters: String,
schema: String,
dataFilters: String,
partitionFilters: String)

case class FailedTaskProfileResults(appIndex: Int, stageId: Int, stageAttemptId: Int,
taskId: Long, taskAttemptId: Int, endReason: String) extends ProfileResult {
override val outputHeaders = Seq("appIndex", "stageId", "stageAttemptId", "taskId",
@@ -104,23 +104,31 @@ trait AppDataSourceViewTrait extends ViewableTrait[DataSourceProfileResult] {
|| sqlAccum.name.contains(IoMetrics.DECODE_TIME_LABEL)
|| sqlAccum.name.equals(IoMetrics.DATA_SIZE_LABEL))

app.dataSourceInfo.map { ds =>
val dsFromLastPlan = app.dataSourceInfo.map { ds =>
val sqlIdtoDs = dataSourceMetrics.filter(
sqlAccum => sqlAccum.sqlID == ds.sqlID && sqlAccum.nodeID == ds.nodeId)
val ioMetrics = if (sqlIdtoDs.nonEmpty) {
getIoMetrics(sqlIdtoDs)
} else {
IoMetrics.EMPTY_IO_METRICS
}
DataSourceProfileResult(index, ds.sqlID, ds.nodeId,
DataSourceProfileResult(index, ds.sqlID, ds.version, ds.nodeId,
ds.format, ioMetrics.bufferTime, ioMetrics.scanTime, ioMetrics.dataSize,
ioMetrics.decodeTime, ds.location, ds.pushedFilters, ds.schema, ds.dataFilters,
ds.partitionFilters)
ds.partitionFilters, ds.isFromFinalPlan)
}
val dsFromOrigPlans = app.sqlManager.getDataSourcesFromOrigPlans.map { ds =>
DataSourceProfileResult(index, ds.sqlID, ds.version, ds.nodeId, ds.format,
IoMetrics.EMPTY_IO_METRICS.bufferTime, IoMetrics.EMPTY_IO_METRICS.scanTime,
IoMetrics.EMPTY_IO_METRICS.dataSize, IoMetrics.EMPTY_IO_METRICS.decodeTime,
ds.location, ds.pushedFilters, ds.schema, ds.dataFilters, ds.partitionFilters,
ds.isFromFinalPlan)
}
dsFromLastPlan ++ dsFromOrigPlans
}

override def sortView(rows: Seq[DataSourceProfileResult]): Seq[DataSourceProfileResult] = {
rows.sortBy(cols => (cols.appIndex, cols.sqlID, cols.location, cols.schema))
rows.sortBy(cols => (cols.appIndex, cols.sqlID, cols.version, cols.location, cols.schema))
}
}

34 changes: 19 additions & 15 deletions core/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
@@ -20,13 +20,13 @@ import java.io.InputStream
import java.util.zip.GZIPInputStream

import scala.collection.immutable
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashSet, Map, SortedMap}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashSet, Map}

import com.nvidia.spark.rapids.SparkRapidsBuildInfoEvent
import com.nvidia.spark.rapids.tool.{DatabricksEventLog, DatabricksRollingEventLogFilesFileReader, EventLogInfo}
import com.nvidia.spark.rapids.tool.planparser.{HiveParseHelper, ReadParser}
import com.nvidia.spark.rapids.tool.planparser.HiveParseHelper.isHiveTableScanNode
import com.nvidia.spark.rapids.tool.profiling.{BlockManagerRemovedCase, DataSourceCase, DriverAccumCase, JobInfoClass, ResourceProfileInfoCase, SQLExecutionInfoClass, SQLPlanMetricsCase}
import com.nvidia.spark.rapids.tool.profiling.{BlockManagerRemovedCase, DriverAccumCase, JobInfoClass, ResourceProfileInfoCase, SQLExecutionInfoClass, SQLPlanMetricsCase}
import com.nvidia.spark.rapids.tool.qualification.AppSubscriber
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
@@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.scheduler.{SparkListenerEvent, StageInfo}
import org.apache.spark.sql.execution.SparkPlanInfo
import org.apache.spark.sql.execution.ui.SparkPlanGraphNode
import org.apache.spark.sql.rapids.tool.store.{AccumManager, StageModel, StageModelManager, TaskModelManager}
import org.apache.spark.sql.rapids.tool.store.{AccumManager, DataSourceRecord, SQLPlanModelManager, StageModel, StageModelManager, TaskModelManager}
import org.apache.spark.sql.rapids.tool.util.{EventUtils, RapidsToolsConfUtil, ToolsPlanGraph, UTF8Source}
import org.apache.spark.util.Utils

@@ -63,12 +63,14 @@ abstract class AppBase(
var blockManagersRemoved: ArrayBuffer[BlockManagerRemovedCase] =
ArrayBuffer[BlockManagerRemovedCase]()
// The data source information
val dataSourceInfo: ArrayBuffer[DataSourceCase] = ArrayBuffer[DataSourceCase]()
val dataSourceInfo: ArrayBuffer[DataSourceRecord] = ArrayBuffer[DataSourceRecord]()

// jobId to job info
val jobIdToInfo = new HashMap[Int, JobInfoClass]()
val jobIdToSqlID: HashMap[Int, Long] = HashMap.empty[Int, Long]

lazy val sqlManager = new SQLPlanModelManager()

// SQL containing any Dataset operation or RDD to DataSet/DataFrame operation
val sqlIDToDataSetOrRDDCase: HashSet[Long] = HashSet[Long]()
// Map (sqlID <-> String(problematic issues))
@@ -78,9 +80,6 @@ abstract class AppBase(
// sqlId to sql info
val sqlIdToInfo = new HashMap[Long, SQLExecutionInfoClass]()
val sqlIdToStages = new HashMap[Long, ArrayBuffer[Int]]()
// sqlPlans stores HashMap (sqlID <-> SparkPlanInfo)
// SortedMap is used to keep the order of the sqlPlans since AQEs can overrides the existing ones
var sqlPlans: Map[Long, SparkPlanInfo] = SortedMap[Long, SparkPlanInfo]()
var sqlPlanMetricsAdaptive: ArrayBuffer[SQLPlanMetricsCase] = ArrayBuffer[SQLPlanMetricsCase]()

// accum id to task stage accum info
@@ -97,6 +96,8 @@ abstract class AppBase(
var sparkRapidsBuildInfo: SparkRapidsBuildInfoEvent = SparkRapidsBuildInfoEvent(immutable.Map(),
immutable.Map(), immutable.Map(), immutable.Map())

def sqlPlans: immutable.Map[Long, SparkPlanInfo] = sqlManager.getPlanInfos

// Returns the String value of the eventlog or empty if it is not defined. Note that the eventlog
// won't be defined for running applications
def getEventLogPath: String = {
@@ -225,7 +226,7 @@ abstract class AppBase(
sqlIDToDataSetOrRDDCase.remove(sqlID)
sqlIDtoProblematic.remove(sqlID)
sqlIdToInfo.remove(sqlID)
sqlPlans.remove(sqlID)
sqlManager.remove(sqlID)
val dsToRemove = dataSourceInfo.filter(_.sqlID == sqlID)
dsToRemove.foreach(dataSourceInfo -= _)

@@ -336,11 +337,11 @@ abstract class AppBase(

// The ReadSchema metadata is only in the eventlog for DataSource V1 readers
def checkMetadataForReadSchema(
sqlPlanInfoGraph: SqlPlanInfoGraphEntry): ArrayBuffer[DataSourceCase] = {
sqlPlanInfoGraph: SqlPlanInfoGraphEntry): ArrayBuffer[DataSourceRecord] = {
// check if planInfo has ReadSchema
val allMetaWithSchema = AppBase.getPlanMetaWithSchema(sqlPlanInfoGraph.planInfo)
val allNodes = sqlPlanInfoGraph.sparkPlanGraph.allNodes
val results = ArrayBuffer[DataSourceCase]()
val results = ArrayBuffer[DataSourceRecord]()

allMetaWithSchema.foreach { plan =>
val meta = plan.metadata
@@ -355,8 +356,9 @@ abstract class AppBase(
// add it to the dataSourceInfo
// Processing Photon eventlogs issue: https://github.com/NVIDIA/spark-rapids-tools/issues/251
if (scanNode.nonEmpty) {
results += DataSourceCase(
results += DataSourceRecord(
sqlPlanInfoGraph.sqlID,
sqlManager.getPlanById(sqlPlanInfoGraph.sqlID).get.plan.version,
scanNode.head.id,
ReadParser.extractTagFromV1ReadMeta("Format", meta),
ReadParser.extractTagFromV1ReadMeta("Location", meta),
@@ -375,8 +377,9 @@ abstract class AppBase(
val sqlGraph = ToolsPlanGraph(hiveReadPlan)
val hiveScanNode = sqlGraph.allNodes.head
val scanHiveMeta = HiveParseHelper.parseReadNode(hiveScanNode)
results += DataSourceCase(
results += DataSourceRecord(
sqlPlanInfoGraph.sqlID,
sqlManager.getPlanById(sqlPlanInfoGraph.sqlID).get.plan.version,
hiveScanNode.id,
scanHiveMeta.format,
scanHiveMeta.location,
@@ -394,11 +397,12 @@ abstract class AppBase(
// This will find scans for DataSource V2, if the schema is very large it
// will likely be incomplete and have ... at the end.
def checkGraphNodeForReads(
sqlID: Long, node: SparkPlanGraphNode): Option[DataSourceCase] = {
sqlID: Long, node: SparkPlanGraphNode): Option[DataSourceRecord] = {
if (ReadParser.isDataSourceV2Node(node)) {
val res = ReadParser.parseReadNode(node)
val dsCase = DataSourceCase(
val dsCase = DataSourceRecord(
sqlID,
sqlManager.getPlanById(sqlID).get.plan.version,
node.id,
res.format,
res.location,
@@ -547,7 +551,7 @@ object AppBase {
(complexTypes.filter(_.nonEmpty), nestedComplexTypes.filter(_.nonEmpty))
}

private def trimSchema(str: String): String = {
def trimSchema(str: String): String = {
val index = str.lastIndexOf(",")
if (index != -1 && str.contains("...")) {
str.substring(0, index)
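
DataSourceCase, removed earlier, is superseded by DataSourceRecord from org.apache.spark.sql.rapids.tool.store. Its definition is outside the visible hunks, so the following is only an inferred sketch from the constructor call sites above and the ds.* accesses in AppDataSourceViewTrait; the field order after location and the exact shape of isFromFinalPlan are assumptions.

```scala
// Inferred sketch only -- the real class lives in
// org.apache.spark.sql.rapids.tool.store and may differ in field order/shape.
case class DataSourceRecord(
    sqlID: Long,
    version: Int,            // SQL plan version in which the scan was found
    nodeId: Long,
    format: String,
    location: String,
    pushedFilters: String,
    schema: String,
    dataFilters: String,
    partitionFilters: String,
    isFromFinalPlan: Boolean = true)  // false when recovered from a superseded plan
```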
@@ -152,7 +152,8 @@ abstract class EventProcessorBase[T <: AppBase](app: T) extends SparkListener wi
hasDatasetOrRDD = false
)
app.sqlIdToInfo.put(event.executionId, sqlExecution)
app.sqlPlans += (event.executionId -> event.sparkPlanInfo)
app.sqlManager.addNewExecution(event.executionId, event.sparkPlanInfo,
event.physicalPlanDescription)
}

def doSparkListenerSQLExecutionEnd(
@@ -183,7 +184,8 @@ abstract class EventProcessorBase[T <: AppBase](app: T) extends SparkListener wi
app: T,
event: SparkListenerSQLAdaptiveExecutionUpdate): Unit = {
// AQE plan can override the ones got from SparkListenerSQLExecutionStart
app.sqlPlans += (event.executionId -> event.sparkPlanInfo)
app.sqlManager.addAQE(event.executionId, event.sparkPlanInfo,
event.physicalPlanDescription)
}

def doSparkListenerSQLAdaptiveSQLMetricUpdates(
@@ -16,7 +16,7 @@

package org.apache.spark.sql.rapids.tool.profiling

import scala.collection.{mutable, Map}
import scala.collection.Map

import com.nvidia.spark.rapids.tool.EventLogInfo
import com.nvidia.spark.rapids.tool.analysis.AppSQLPlanAnalyzer
@@ -187,9 +187,6 @@ class ApplicationInfo(
val index: Int)
extends AppBase(Some(eLogInfo), Some(hadoopConf)) with Logging {

// physicalPlanDescription stores HashMap (sqlID <-> physicalPlanDescription)
var physicalPlanDescription: mutable.HashMap[Long, String] = mutable.HashMap.empty[Long, String]

private lazy val eventProcessor = new EventsProcessor(this)

// Process all events