From e1db8ded218455f49ed82085dc6a666a240803ed Mon Sep 17 00:00:00 2001
From: Shri Saran Raj N
Date: Sun, 26 Jan 2025 16:58:37 +0530
Subject: [PATCH 1/2] Add FlintJob to support queries in warmpool mode

Signed-off-by: Shri Saran Raj N
---
 .../flint/core/metrics/MetricConstants.java   |  25 ++++
 .../sql/flint/config/FlintSparkConf.scala     |  13 ++
 .../apache/spark/sql/FlintJobITSuite.scala    |   6 +-
 .../scala/org/apache/spark/sql/FlintJob.scala | 100 ++++++++-----
 .../apache/spark/sql/FlintJobExecutor.scala   |  29 ++++
 .../org/apache/spark/sql/FlintREPL.scala      |  30 +---
 .../org/apache/spark/sql/JobOperator.scala    | 140 +++++++++++++-----
 .../org/apache/spark/sql/WarmpoolJob.scala    | 122 +++++++++++++++
 .../apache/spark/sql/util/WarmpoolTest.scala  | 108 ++++++++++++++
 9 files changed, 472 insertions(+), 101 deletions(-)
 create mode 100644 spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala
 create mode 100644 spark-sql-application/src/test/scala/org/apache/spark/sql/util/WarmpoolTest.scala

diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java
index 79e70b8c2..6e2ff42d6 100644
--- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java
+++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/MetricConstants.java
@@ -95,6 +95,11 @@ public final class MetricConstants {
    */
   public static final String RESULT_METADATA_WRITE_METRIC_PREFIX = "result.metadata.write";

+  /**
+   * Prefix for metrics related to interactive queries
+   */
+  public static final String STATEMENT = "statement";
+
   /**
    * Metric name for counting the number of statements currently running.
    */
@@ -135,11 +140,31 @@ public final class MetricConstants {
    */
   public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count";

+  /**
+   * Metric for tracking the count of jobs failed during query execution
+   */
+  public static final String QUERY_EXECUTION_FAILED_METRIC = "execution.failed.count";
+
+  /**
+   * Metric for tracking the count of jobs failed during query result write
+   */
+  public static final String RESULT_WRITER_FAILED_METRIC = "writer.failed.count";
+
   /**
    * Metric for tracking the latency of query execution (start to complete query execution) excluding result write.
    */
   public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime";

+  /**
+   * Metric for tracking the latency of query result write only (excluding query execution)
+   */
+  public static final String QUERY_RESULT_WRITER_TIME_METRIC = "result.writer.processingTime";
+
+  /**
+   * Metric for tracking the latency of query total execution including result write.
+ */ + public static final String QUERY_TOTAL_TIME_METRIC = "query.total.processingTime"; + /** * Metric for query count of each query type (DROP/VACUUM/ALTER/REFRESH/CREATE INDEX) */ diff --git a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala index 364a8a1de..730603c6a 100644 --- a/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala +++ b/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala @@ -171,6 +171,12 @@ object FlintSparkConf { .doc("Enable external scheduler for index refresh") .createWithDefault("false") + val WARMPOOL_ENABLED = + FlintConfig("spark.flint.job.warmpoolEnabled") + .createWithDefault("false") + + val MAX_EXECUTORS_COUNT = FlintConfig("spark.dynamicAllocation.maxExecutors").createOptional() + val EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD = FlintConfig("spark.flint.job.externalScheduler.interval") .doc("Interval threshold in minutes for external scheduler to trigger index refresh") @@ -246,6 +252,10 @@ object FlintSparkConf { FlintConfig(s"spark.flint.job.requestIndex") .doc("Request index") .createOptional() + val RESULT_INDEX = + FlintConfig(s"spark.flint.job.resultIndex") + .doc("Result index") + .createOptional() val EXCLUDE_JOB_IDS = FlintConfig(s"spark.flint.deployment.excludeJobs") .doc("Exclude job ids") @@ -271,6 +281,9 @@ object FlintSparkConf { val CUSTOM_QUERY_RESULT_WRITER = FlintConfig("spark.flint.job.customQueryResultWriter") .createOptional() + val TERMINATE_JVM = FlintConfig("spark.flint.terminateJVM") + .doc("Indicates whether the JVM should be terminated after query execution") + .createWithDefault("true") } /** diff --git a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala index 81bf60f5e..6b7a71f5d 100644 --- a/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala +++ b/integ-test/src/integration/scala/org/apache/spark/sql/FlintJobITSuite.scala @@ -39,6 +39,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { val appId = "00feq82b752mbt0p" val dataSourceName = "my_glue1" val queryId = "testQueryId" + val requestIndex = "testRequestIndex" var osClient: OSClient = _ val threadLocalFuture = new ThreadLocal[Future[Unit]]() @@ -83,6 +84,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { def createJobOperator(query: String, jobRunId: String): JobOperator = { val streamingRunningCount = new AtomicInteger(0) + val statementRunningCount = new AtomicInteger(0) /* * Because we cannot test from FlintJob.main() for the reason below, we have to configure @@ -90,6 +92,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { */ spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName) spark.conf.set(JOB_TYPE.key, FlintJobType.STREAMING) + spark.conf.set(REQUEST_INDEX.key, requestIndex) val job = JobOperator( appId, @@ -100,7 +103,8 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest { dataSourceName, resultIndex, FlintJobType.STREAMING, - streamingRunningCount) + streamingRunningCount, + statementRunningCount) job.terminateJVM = false job } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala index 04609cf3d..5444c88d8 100644 --- 
a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJob.scala @@ -8,12 +8,15 @@ package org.apache.spark.sql import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} + import org.opensearch.flint.core.logging.CustomLogging import org.opensearch.flint.core.metrics.MetricConstants import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge import org.apache.spark.internal.Logging import org.apache.spark.sql.flint.config.FlintSparkConf +import org.apache.spark.util.ThreadUtils /** * Spark SQL Application entrypoint @@ -26,52 +29,71 @@ import org.apache.spark.sql.flint.config.FlintSparkConf * write sql query result to given opensearch index */ object FlintJob extends Logging with FlintJobExecutor { + private val streamingRunningCount = new AtomicInteger(0) + private val statementRunningCount = new AtomicInteger(0) + def main(args: Array[String]): Unit = { val (queryOption, resultIndexOption) = parseArgs(args) val conf = createSparkConf() - val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH) - CustomLogging.logInfo(s"""Job type is: ${jobType}""") - conf.set(FlintSparkConf.JOB_TYPE.key, jobType) - - val dataSource = conf.get("spark.flint.datasource.name", "") - val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, ""))) - if (query.isEmpty) { - logAndThrow(s"Query undefined for the ${jobType} job.") - } - val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "") - - if (resultIndexOption.isEmpty) { - logAndThrow("resultIndex is not set") - } - // https://github.com/opensearch-project/opensearch-spark/issues/138 - /* - * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, - * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), - * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. - * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. - * Without this setup, Spark would not recognize names in the format `my_glue1.default`. 
- */ - conf.set("spark.sql.defaultCatalog", dataSource) - configDYNMaxExecutors(conf, jobType) - + val sparkSession = createSparkSession(conf) val applicationId = environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown") val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown") + val isWarmpoolEnabled = conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean + logInfo(s"isWarmpoolEnabled: ${isWarmpoolEnabled}") + + if (!isWarmpoolEnabled) { + val jobType = sparkSession.conf.get("spark.flint.job.type", FlintJobType.BATCH) + CustomLogging.logInfo(s"""Job type is: ${jobType}""") + sparkSession.conf.set(FlintSparkConf.JOB_TYPE.key, jobType) + + val dataSource = conf.get("spark.flint.datasource.name", "") + val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, ""))) + if (query.isEmpty) { + logAndThrow(s"Query undefined for the ${jobType} job.") + } + val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "") - val streamingRunningCount = new AtomicInteger(0) - val jobOperator = - JobOperator( - applicationId, - jobId, - createSparkSession(conf), - query, - queryId, - dataSource, - resultIndexOption.get, - jobType, - streamingRunningCount) - registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) - jobOperator.start() + if (resultIndexOption.isEmpty) { + logAndThrow("resultIndex is not set") + } + // https://github.com/opensearch-project/opensearch-spark/issues/138 + /* + * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`, + * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain), + * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table. + * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly. + * Without this setup, Spark would not recognize names in the format `my_glue1.default`. 
+ */ + conf.set("spark.sql.defaultCatalog", dataSource) + configDYNMaxExecutors(conf, jobType) + + val jobOperator = + JobOperator( + applicationId, + jobId, + sparkSession, + query, + queryId, + dataSource, + resultIndexOption.get, + jobType, + streamingRunningCount, + statementRunningCount, + Map.empty) + registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) + jobOperator.start() + } else { + // Fetch and execute queries in warm pool mode + val warmpoolJob = + WarmpoolJob( + applicationId, + jobId, + sparkSession, + streamingRunningCount, + statementRunningCount) + warmpoolJob.start() + } } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala index ad26cf21a..8b1ddea40 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintJobExecutor.scala @@ -6,6 +6,7 @@ package org.apache.spark.sql import java.util.Locale +import java.util.concurrent.ThreadPoolExecutor import com.amazonaws.services.glue.model.{AccessDeniedException, AWSGlueException} import com.amazonaws.services.s3.model.AmazonS3Exception @@ -20,6 +21,7 @@ import play.api.libs.json._ import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.sql.FlintREPL.instantiate import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.exception.UnrecoverableException @@ -566,4 +568,31 @@ trait FlintJobExecutor { } } } + + def instantiateQueryResultWriter( + spark: SparkSession, + commandContext: CommandContext): QueryResultWriter = { + instantiate( + new QueryResultWriterImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) + } + + def instantiateStatementExecutionManager( + commandContext: CommandContext): StatementExecutionManager = { + import commandContext._ + instantiate( + new StatementExecutionManagerImpl(commandContext), + spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), + spark, + sessionId) + } + + def instantiateSessionManager( + spark: SparkSession, + resultIndexOption: Option[String]): SessionManager = { + instantiate( + new SessionManagerImpl(spark, resultIndexOption), + spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""), + resultIndexOption.getOrElse("")) + } } diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala index 6d7dcc0e7..b33a041dd 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/FlintREPL.scala @@ -106,7 +106,8 @@ object FlintREPL extends Logging with FlintJobExecutor { dataSource, resultIndexOption.get, jobType, - streamingRunningCount) + streamingRunningCount, + statementRunningCount) registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount) jobOperator.start() } else { @@ -1021,33 +1022,6 @@ object FlintREPL extends Logging with FlintJobExecutor { } } - private def instantiateSessionManager( - spark: SparkSession, - resultIndexOption: Option[String]): SessionManager = { - instantiate( - new SessionManagerImpl(spark, resultIndexOption), - spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, 
""), - resultIndexOption.getOrElse("")) - } - - private def instantiateStatementExecutionManager( - commandContext: CommandContext): StatementExecutionManager = { - import commandContext._ - instantiate( - new StatementExecutionManagerImpl(commandContext), - spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""), - spark, - sessionId) - } - - private def instantiateQueryResultWriter( - spark: SparkSession, - commandContext: CommandContext): QueryResultWriter = { - instantiate( - new QueryResultWriterImpl(commandContext), - spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "")) - } - private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = { logInfo("Session Success") stopTimer(sessionTimerContext) diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala index 27b0be84f..dfd5e5759 100644 --- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala +++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala @@ -15,7 +15,6 @@ import scala.util.{Failure, Success, Try} import org.opensearch.flint.common.model.FlintStatement import org.opensearch.flint.common.scheduler.model.LangType import org.opensearch.flint.core.metrics.{MetricConstants, MetricsSparkListener, MetricsUtil} -import org.opensearch.flint.core.metrics.MetricsUtil.incrementCounter import org.opensearch.flint.spark.FlintSpark import org.apache.spark.internal.Logging @@ -32,26 +31,40 @@ case class JobOperator( dataSource: String, resultIndex: String, jobType: String, - streamingRunningCount: AtomicInteger) + streamingRunningCount: AtomicInteger, + statementRunningCount: AtomicInteger, + statementContext: Map[String, Any] = Map.empty[String, Any]) extends Logging with FlintJobExecutor { // JVM shutdown hook sys.addShutdownHook(stop()) + val isStreamingOrBatch = + jobType.equalsIgnoreCase(FlintJobType.STREAMING) || jobType.equalsIgnoreCase( + FlintJobType.BATCH) + val isWarmpoolEnabled = + sparkSession.conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean + val segmentName = getSegmentName(sparkSession) def start(): Unit = { val threadPool = ThreadUtils.newDaemonFixedThreadPool(1, "check-create-index") implicit val executionContext = ExecutionContext.fromExecutor(threadPool) - var dataToWrite: Option[DataFrame] = None val startTime = System.currentTimeMillis() - streamingRunningCount.incrementAndGet() + if (isStreamingOrBatch) { + streamingRunningCount.incrementAndGet() + } else { + statementRunningCount.incrementAndGet() + } // osClient needs spark session to be created first to get FlintOptions initialized. // Otherwise, we will have connection exception from EMR-S to OS. val osClient = new OSClient(FlintSparkConf().flintOptions()) + // QueryResultWriter depends on sessionManager to fetch the sessionContext + val sessionManager = instantiateSessionManager(sparkSession, Some(resultIndex)) + // TODO: Update FlintJob to Support All Query Types. 
Track on https://github.com/opensearch-project/opensearch-spark/issues/633 val commandContext = CommandContext( applicationId, @@ -60,7 +73,7 @@ case class JobOperator( dataSource, jobType, "", // FlintJob doesn't have sessionId - null, // FlintJob doesn't have SessionManager + sessionManager, Duration.Inf, // FlintJob doesn't have queryExecutionTimeout -1, // FlintJob doesn't have inactivityLimitMillis -1, // FlintJob doesn't have queryWaitTimeMillis @@ -80,7 +93,9 @@ case class JobOperator( "", queryId, LangType.SQL, - currentTimeProvider.currentEpochMillis()) + currentTimeProvider.currentEpochMillis(), + Option.empty, + statementContext) try { val futurePrepareQueryExecution = Future { @@ -119,6 +134,7 @@ case class JobOperator( query, "", startTime)) + incrementCounter(MetricConstants.QUERY_EXECUTION_FAILED_METRIC) case t: Throwable => val error = processQueryException(t) dataToWrite = Some( @@ -133,18 +149,30 @@ case class JobOperator( query, "", startTime)) + incrementCounter(MetricConstants.QUERY_EXECUTION_FAILED_METRIC) } finally { - emitQueryExecutionTimeMetric(startTime) + emitTimerMetric(MetricConstants.QUERY_EXECUTION_TIME_METRIC, startTime) readWriteBytesSparkListener.emitMetrics() sparkSession.sparkContext.removeSparkListener(readWriteBytesSparkListener) + val resultWriterStartTime = System.currentTimeMillis() try { - dataToWrite.foreach(df => writeDataFrameToOpensearch(df, resultIndex, osClient)) + dataToWrite.foreach(df => { + if (isStreamingOrBatch) { + writeDataFrameToOpensearch(df, resultIndex, osClient) + } else { + val queryResultWriter = instantiateQueryResultWriter(sparkSession, commandContext) + queryResultWriter.writeDataFrame(df, statement) + } + }) } catch { case t: Throwable => + incrementCounter(MetricConstants.RESULT_WRITER_FAILED_METRIC) throwableHandler.recordThrowable( - s"Failed to write to result index. originalError='${throwableHandler.error}'", + s"Failed to write to result. originalError='${t.getMessage}'", t) + } finally { + emitTimerMetric(MetricConstants.QUERY_RESULT_WRITER_TIME_METRIC, resultWriterStartTime) } if (throwableHandler.hasException) statement.fail() else statement.complete() statement.error = Some(throwableHandler.error) @@ -157,30 +185,29 @@ case class JobOperator( s"Failed to update statement. originalError='${throwableHandler.error}'", t) } - + emitTimerMetric(MetricConstants.QUERY_TOTAL_TIME_METRIC, startTime) cleanUpResources(threadPool) } } def cleanUpResources(threadPool: ThreadPoolExecutor): Unit = { - val isStreaming = jobType.equalsIgnoreCase(FlintJobType.STREAMING) try { - // Wait for streaming job complete if no error - if (!throwableHandler.hasException && isStreaming) { + // Wait for job complete if no error + if (!throwableHandler.hasException && isStreamingOrBatch) { // Clean Spark shuffle data after each microBatch. 
sparkSession.streams.addListener(new ShuffleCleaner(sparkSession)) // Await index monitor before the main thread terminates new FlintSpark(sparkSession).flintIndexMonitor.awaitMonitor() } else { logInfo(s""" - | Skip streaming job await due to conditions not met: - | - exceptionThrown: ${throwableHandler.hasException} - | - streaming: $isStreaming - | - activeStreams: ${sparkSession.streams.active.mkString(",")} - |""".stripMargin) + | Skip job await due to conditions not met: + | - exceptionThrown: ${throwableHandler.hasException} + | - streaming: $isStreamingOrBatch + | - activeStreams: ${sparkSession.streams.active.mkString(",")} + |""".stripMargin) } } catch { - case e: Exception => logError("streaming job failed", e) + case e: Exception => logError("job failed", e) } try { @@ -190,7 +217,7 @@ case class JobOperator( } catch { case e: Exception => logError("Fail to close threadpool", e) } - recordStreamingCompletionStatus(throwableHandler.hasException) + recordCompletionStatus(throwableHandler.hasException) // Check for non-daemon threads that may prevent the driver from shutting down. // Non-daemon threads other than the main thread indicate that the driver is still processing tasks, @@ -205,11 +232,9 @@ case class JobOperator( } } - private def emitQueryExecutionTimeMetric(startTime: Long): Unit = { + private def emitTimerMetric(metricName: String, startTime: Long): Unit = { MetricsUtil - .addHistoricGauge( - MetricConstants.QUERY_EXECUTION_TIME_METRIC, - System.currentTimeMillis() - startTime) + .addHistoricGauge(resolveMetricName(metricName), System.currentTimeMillis() - startTime) } def stop(): Unit = { @@ -229,22 +254,34 @@ case class JobOperator( } /** - * Records the completion of a streaming job by updating the appropriate metrics. This method - * decrements the running metric for streaming jobs and increments either the success or failure - * metric based on whether an exception was thrown. + * Records the completion of a job by updating the appropriate metrics. This method decrements + * the running metric for jobs and increments either the success or failure metric based on + * whether an exception was thrown. * * @param exceptionThrown - * Indicates whether an exception was thrown during the streaming job execution. + * Indicates whether an exception was thrown during the job execution. */ - private def recordStreamingCompletionStatus(exceptionThrown: Boolean): Unit = { - // Decrement the metric for running streaming jobs as the job is now completing. + private def recordCompletionStatus(exceptionThrown: Boolean): Unit = { + // Decrement the metric for running jobs as the job is now completing. 
     if (streamingRunningCount.get() > 0) {
       streamingRunningCount.decrementAndGet()
+    } else if (statementRunningCount.get() > 0) {
+      statementRunningCount.decrementAndGet()
     }
-    exceptionThrown match {
-      case true => incrementCounter(MetricConstants.STREAMING_FAILED_METRIC)
-      case false => incrementCounter(MetricConstants.STREAMING_SUCCESS_METRIC)
+    val metric = {
+      (exceptionThrown, isStreamingOrBatch) match {
+        case (true, true) => MetricConstants.STREAMING_FAILED_METRIC
+        case (true, false) => MetricConstants.STATEMENT_FAILED_METRIC
+        case (false, true) => MetricConstants.STREAMING_SUCCESS_METRIC
+        case (false, false) => MetricConstants.STATEMENT_SUCCESS_METRIC
+      }
+    }
+
+    if (isWarmpoolEnabled) {
+      MetricsUtil.incrementCounter(String.format("%s.%s", segmentName, metric));
+    } else {
+      MetricsUtil.incrementCounter(metric);
     }
   }

@@ -259,4 +296,41 @@
       spark,
       sessionId)
   }
+
+  /**
+   * Returns a segment name formatted with the maximum executor count. For example, if the max
+   * executor count is 1, the return value will be "1e".
+   *
+   * @param spark
+   *   The Spark session.
+   * @return
+   *   A string in the format "e", e.g., "1e", "2e".
+   */
+  private def getSegmentName(spark: SparkSession): String = {
+    val maxExecutorsCount = spark.conf.get(FlintSparkConf.MAX_EXECUTORS_COUNT.key, "unknown")
+    String.format("%se", maxExecutorsCount)
+  }
+
+  /**
+   * Resolves the full metric name based on the job type and warm pool status. If the warm pool is
+   * enabled, the metric name is prefixed with the segment name and job type (STREAMING or
+   * STATEMENT).
+   *
+   * @param metricName
+   *   The base metric name to resolve.
+   * @return
+   *   The resolved metric name, e.g., "1e.streaming.success.count" if warm pool is enabled.
+   */
+  private def resolveMetricName(metricName: String): String = {
+    if (isWarmpoolEnabled) {
+      val jobType = if (isStreamingOrBatch) FlintJobType.STREAMING else MetricConstants.STATEMENT
+      val newMetricName = String.format("%s.%s.%s", segmentName, jobType, metricName)
+      return newMetricName
+    }
+    metricName
+  }
+
+  private def incrementCounter(metricName: String): Unit = {
+    MetricsUtil.incrementCounter(resolveMetricName(metricName))
+  }
 }
diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala
new file mode 100644
index 000000000..23648a02c
--- /dev/null
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/WarmpoolJob.scala
@@ -0,0 +1,122 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.apache.spark.sql
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.concurrent.duration.Duration
+
+import org.opensearch.flint.common.model.FlintStatement
+import org.opensearch.flint.core.metrics.MetricConstants
+import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.flint.config.FlintSparkConf
+
+/**
+ * This class executes Spark jobs in "warm pool" mode, repeatedly calling the client to fetch
+ * query details (job type, data source, configurations). The job is created without any
+ * query-specific configurations, and the client sets the Spark configurations at runtime during
+ * each iteration
+ */
+case class WarmpoolJob(
+    applicationId: String,
+    jobId: String,
+    spark: SparkSession,
+    streamingRunningCount: AtomicInteger,
+    statementRunningCount: AtomicInteger)
+    extends Logging
+    with FlintJobExecutor {
+
+  def start(): Unit = {
+    val commandContext = CommandContext(
+      applicationId,
+      jobId,
+      spark,
+      "", // datasource is not known yet
+      "", // jobType is not known yet
+      "", // WP doesn't have sessionId
+      null, // WP doesn't use SessionManager
+      Duration.Inf, // WP doesn't have queryExecutionTimeout
+      -1, // WP doesn't have inactivityLimitMillis
+      -1, // WP doesn't have queryWaitTimeMillis
+      -1 // WP doesn't have queryLoopExecutionFrequency
+    )
+
+    registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
+    registerGauge(MetricConstants.STATEMENT_RUNNING_METRIC, statementRunningCount)
+    val statementExecutionManager =
+      instantiateStatementExecutionManager(commandContext)
+
+    queryLoop(statementExecutionManager)
+  }
+
+  /**
+   * Executes statements from the StatementExecutionManager in a loop until no more statements are
+   * available.
+   */
+  def queryLoop(statementExecutionManager: StatementExecutionManager): Unit = {
+    var canProceed = true
+
+    try {
+      while (canProceed) {
+        statementExecutionManager.getNextStatement() match {
+          case Some(flintStatement) =>
+            flintStatement.running()
+            statementExecutionManager.updateStatement(flintStatement)
+
+            val jobType = spark.conf.get(FlintSparkConf.JOB_TYPE.key, FlintJobType.BATCH)
+            val dataSource = spark.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key)
+            val resultIndex = spark.conf.get(FlintSparkConf.RESULT_INDEX.key)
+            val jobOperator = createJobOperator(flintStatement, dataSource, resultIndex, jobType)
+
+            // The client sets this Spark configuration at runtime for each iteration
+            // to control whether the JVM should be terminated after the query execution.
+            jobOperator.terminateJVM =
+              spark.conf.get(FlintSparkConf.TERMINATE_JVM.key, "true").toBoolean
+            jobOperator.start()
+
+          case _ =>
+            canProceed = false
+        }
+      }
+    } catch {
+      case t: Throwable =>
+        // Record and rethrow in query loop
+        throwableHandler.recordThrowable(s"Query loop execution failed.", t)
+        throw t
+    }
+  }
+
+  def createJobOperator(
+      flintStatement: FlintStatement,
+      dataSource: String,
+      resultIndex: String,
+      jobType: String): JobOperator = {
+    // https://github.com/opensearch-project/opensearch-spark/issues/138
+    /*
+     * To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
+     * it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
+     * and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
+     * By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
+     * Without this setup, Spark would not recognize names in the format `my_glue1.default`.
+     */
+    spark.conf.set("spark.sql.defaultCatalog", dataSource)
+    val jobOperator =
+      JobOperator(
+        applicationId,
+        jobId,
+        spark,
+        flintStatement.query,
+        flintStatement.queryId,
+        dataSource,
+        resultIndex,
+        jobType,
+        streamingRunningCount,
+        statementRunningCount,
+        flintStatement.context)
+    jobOperator
+  }
+}
diff --git a/spark-sql-application/src/test/scala/org/apache/spark/sql/util/WarmpoolTest.scala b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/WarmpoolTest.scala
new file mode 100644
index 000000000..381a84226
--- /dev/null
+++ b/spark-sql-application/src/test/scala/org/apache/spark/sql/util/WarmpoolTest.scala
@@ -0,0 +1,108 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.apache.spark.sql
+
+import java.time.Instant
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.mockito.ArgumentMatchers.{any, anyString}
+import org.mockito.Mockito.{doAnswer, spy, times, verify, when}
+import org.opensearch.flint.common.model.FlintStatement
+import org.opensearch.flint.common.scheduler.model.LangType
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.flint.config.FlintSparkConf
+
+class WarmpoolTest extends SparkFunSuite with MockitoSugar with JobMatchers {
+  private val jobId = "testJobId"
+  private val applicationId = "testApplicationId"
+  val streamingRunningCount = new AtomicInteger(0)
+  val statementRunningCount = new AtomicInteger(0)
+  var mockStatementExecutionManager: StatementExecutionManager = _
+  val resultIndex = "testResultIndex"
+  val dataSourceName = "my_glue1"
+  val requestIndex = "testRequestIndex"
+
+  test("verify job operator starts twice when there are two Flint statements") {
+    val mockSparkSession = mock[SparkSession]
+    val mockConf = mock[RuntimeConfig]
+    val mockStatementExecutionManager = mock[StatementExecutionManager]
+    val mockJobOperator = mock[JobOperator]
+
+    val firstFlintStatement = new FlintStatement(
+      "waiting",
+      "select 1",
+      "30",
+      "10",
+      LangType.SQL,
+      Instant.now().toEpochMilli(),
+      None)
+
+    val secondFlintStatement = new FlintStatement(
+      "waiting",
+      "select * from DB",
+      "30",
+      "10",
+      LangType.SQL,
+      Instant.now().toEpochMilli(),
+      None)
+
+    when(mockSparkSession.conf).thenReturn(mockConf)
+    when(mockStatementExecutionManager.getNextStatement())
+      .thenReturn(Some(firstFlintStatement))
+      .thenReturn(Some(secondFlintStatement))
+      .thenReturn(None)
+
+    when(mockSparkSession.conf.get(FlintSparkConf.JOB_TYPE.key, FlintJobType.BATCH))
+      .thenReturn(FlintJobType.BATCH)
+    when(mockSparkSession.conf.get(FlintSparkConf.DATA_SOURCE_NAME.key))
+      .thenReturn(dataSourceName)
+    when(mockSparkSession.conf.get(FlintSparkConf.RESULT_INDEX.key)).thenReturn(resultIndex)
+    when(mockSparkSession.conf.get(FlintSparkConf.TERMINATE_JVM.key, "true")).thenReturn("true")
+    when(mockSparkSession.conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false"))
+      .thenReturn("true")
+    when(mockSparkSession.conf.get(FlintSparkConf.REQUEST_INDEX.key, ""))
+      .thenReturn(requestIndex)
+
+    val job = spy(
+      WarmpoolJob(
+        applicationId,
+        jobId,
+        mockSparkSession,
+        statementRunningCount,
+        statementRunningCount))
+
+    doAnswer(_ => mockJobOperator)
+      .when(job)
+      .createJobOperator(any(), anyString(), anyString(), anyString())
+
+    job.queryLoop(mockStatementExecutionManager)
+    verify(mockJobOperator, times(2)).start()
+  }
+
+  test("Query loop execution failure") {
+    val mockSparkSession = mock[SparkSession]
+    val mockConf = mock[RuntimeConfig]
+    val mockStatementExecutionManager = mock[StatementExecutionManager]
+
+    when(mockSparkSession.conf).thenReturn(mockConf)
+    when(mockStatementExecutionManager.getNextStatement())
+      .thenThrow(new RuntimeException("something went wrong"))
+
+    val job =
+      WarmpoolJob(
+        applicationId,
+        jobId,
+        mockSparkSession,
+        statementRunningCount,
+        statementRunningCount)
+
+    assertThrows[Throwable] {
+      job.queryLoop(mockStatementExecutionManager)
+    }
+  }
+}

From 83e24a4b0613ad55c4452927c37dd45d048afadd Mon Sep 17 00:00:00 2001
From: Shri Saran Raj N
Date: Tue, 28 Jan 2025 14:39:30 +0530
Subject: [PATCH 2/2] Revert error message change

Signed-off-by: Shri Saran Raj N
---
 .../src/main/scala/org/apache/spark/sql/JobOperator.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
index dfd5e5759..153f22596 100644
--- a/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
+++ b/spark-sql-application/src/main/scala/org/apache/spark/sql/JobOperator.scala
@@ -169,7 +169,7 @@ case class JobOperator(
       case t: Throwable =>
         incrementCounter(MetricConstants.RESULT_WRITER_FAILED_METRIC)
         throwableHandler.recordThrowable(
-          s"Failed to write to result. originalError='${t.getMessage}'",
+          s"Failed to write to result. originalError='${throwableHandler.error}'",
           t)
     } finally {
       emitTimerMetric(MetricConstants.QUERY_RESULT_WRITER_TIME_METRIC, resultWriterStartTime)
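
A minimal, self-contained sketch (not part of the patch) of how the segment-prefixed metric names introduced in JobOperator are expected to compose, assuming warm pool mode is enabled (spark.flint.job.warmpoolEnabled=true) and spark.dynamicAllocation.maxExecutors is set. The object and method names below are illustrative only; in the patch the same logic lives in JobOperator.getSegmentName and JobOperator.resolveMetricName, and the base metric names come from MetricConstants.

object WarmpoolMetricNamingSketch {
  // Mirrors getSegmentName: "<maxExecutors>e", e.g. a max executor count of 1 yields "1e".
  def segmentName(maxExecutorsCount: String): String = s"${maxExecutorsCount}e"

  // Mirrors resolveMetricName: in warm pool mode the base metric name is prefixed with
  // "<segment>.<jobType>", where jobType is "streaming" for streaming/batch jobs and
  // "statement" (MetricConstants.STATEMENT) for interactive queries.
  def resolve(maxExecutorsCount: String, isStreamingOrBatch: Boolean, baseMetric: String): String = {
    val jobType = if (isStreamingOrBatch) "streaming" else "statement"
    s"${segmentName(maxExecutorsCount)}.$jobType.$baseMetric"
  }

  def main(args: Array[String]): Unit = {
    // Prints "1e.streaming.query.execution.processingTime"
    println(resolve("1", true, "query.execution.processingTime"))
    // Prints "2e.statement.writer.failed.count"
    println(resolve("2", false, "writer.failed.count"))
  }
}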