Implement FlintJob to handle all query types in warmpool mode #979

Open · wants to merge 2 commits into main
@@ -95,6 +95,11 @@ public final class MetricConstants {
*/
public static final String RESULT_METADATA_WRITE_METRIC_PREFIX = "result.metadata.write";

/**
* Prefix for metrics related to interactive queries.
*/
public static final String STATEMENT = "statement";

/**
* Metric name for counting the number of statements currently running.
*/
@@ -135,11 +140,31 @@ public final class MetricConstants {
*/
public static final String STREAMING_HEARTBEAT_FAILED_METRIC = "streaming.heartbeat.failed.count";

/**
* Metric for tracking the count of jobs that failed during query execution.
*/
public static final String QUERY_EXECUTION_FAILED_METRIC = "execution.failed.count";

/**
* Metric for tracking the count of jobs that failed during query result write.
*/
public static final String RESULT_WRITER_FAILED_METRIC = "writer.failed.count";

/**
* Metric for tracking the latency of query execution (start to complete query execution) excluding result write.
*/
public static final String QUERY_EXECUTION_TIME_METRIC = "query.execution.processingTime";

/**
* Metric for tracking the latency of the query result write only (excluding query execution).
*/
public static final String QUERY_RESULT_WRITER_TIME_METRIC = "result.writer.processingTime";

/**
* Metric for tracking the latency of query total execution including result write.
*/
public static final String QUERY_TOTAL_TIME_METRIC = "query.total.processingTime";

/**
* Metric for query count of each query type (DROP/VACUUM/ALTER/REFRESH/CREATE INDEX)
*/
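The new counters and timers above are only declared in this file. As a rough illustration (not part of this diff), a statement run could emit them around query execution and result writing as sketched below, assuming MetricsUtil exposes incrementCounter, getTimerContext and stopTimer the same way the existing REPL metrics use them.

// Illustrative sketch only, not code from this PR.
import org.opensearch.flint.core.metrics.{MetricConstants, MetricsUtil}

def runWithQueryMetrics[T](execute: => T)(writeResult: T => Unit): Unit = {
  val totalTimer = MetricsUtil.getTimerContext(MetricConstants.QUERY_TOTAL_TIME_METRIC)
  try {
    val execTimer = MetricsUtil.getTimerContext(MetricConstants.QUERY_EXECUTION_TIME_METRIC)
    val result =
      try execute
      catch {
        case t: Throwable =>
          // Count jobs that fail while executing the query.
          MetricsUtil.incrementCounter(MetricConstants.QUERY_EXECUTION_FAILED_METRIC)
          throw t
      } finally MetricsUtil.stopTimer(execTimer)

    val writeTimer = MetricsUtil.getTimerContext(MetricConstants.QUERY_RESULT_WRITER_TIME_METRIC)
    try writeResult(result)
    catch {
      case t: Throwable =>
        // Count jobs that fail while writing the query result.
        MetricsUtil.incrementCounter(MetricConstants.RESULT_WRITER_FAILED_METRIC)
        throw t
    } finally MetricsUtil.stopTimer(writeTimer)
  } finally MetricsUtil.stopTimer(totalTimer)
}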
@@ -171,6 +171,12 @@ object FlintSparkConf {
.doc("Enable external scheduler for index refresh")
.createWithDefault("false")

val WARMPOOL_ENABLED =
FlintConfig("spark.flint.job.warmpoolEnabled")
.doc("Enable warm pool mode, where FlintJob fetches and executes queries of all types")
.createWithDefault("false")

val MAX_EXECUTORS_COUNT = FlintConfig("spark.dynamicAllocation.maxExecutors").createOptional()

val EXTERNAL_SCHEDULER_INTERVAL_THRESHOLD =
FlintConfig("spark.flint.job.externalScheduler.interval")
.doc("Interval threshold in minutes for external scheduler to trigger index refresh")
@@ -246,6 +252,10 @@ object FlintSparkConf {
FlintConfig(s"spark.flint.job.requestIndex")
.doc("Request index")
.createOptional()
val RESULT_INDEX =
FlintConfig(s"spark.flint.job.resultIndex")
.doc("Result index")
.createOptional()
val EXCLUDE_JOB_IDS =
FlintConfig(s"spark.flint.deployment.excludeJobs")
.doc("Exclude job ids")
@@ -271,6 +281,9 @@
val CUSTOM_QUERY_RESULT_WRITER =
FlintConfig("spark.flint.job.customQueryResultWriter")
.createOptional()
val TERMINATE_JVM = FlintConfig("spark.flint.terminateJVM")
.doc("Indicates whether the JVM should be terminated after query execution")
.createWithDefault("true")
}

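For context (not part of this diff): the new keys are plain Spark configuration entries, so a minimal sketch of setting and reading them with the accessors visible in this PR looks like the following (the values are placeholders).

import org.apache.spark.SparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf

val conf = new SparkConf()
  .set(FlintSparkConf.WARMPOOL_ENABLED.key, "true")       // spark.flint.job.warmpoolEnabled
  .set(FlintSparkConf.RESULT_INDEX.key, "query_results")  // spark.flint.job.resultIndex (example index name)
  .set(FlintSparkConf.TERMINATE_JVM.key, "false")         // spark.flint.terminateJVM

val isWarmpoolEnabled = conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean
val resultIndex = conf.getOption(FlintSparkConf.RESULT_INDEX.key)

The terminateJVM flag presumably backs the same behavior that FlintJobITSuite below toggles directly via job.terminateJVM = false to keep the test JVM alive.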
@@ -39,6 +39,7 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {
val appId = "00feq82b752mbt0p"
val dataSourceName = "my_glue1"
val queryId = "testQueryId"
val requestIndex = "testRequestIndex"
var osClient: OSClient = _
val threadLocalFuture = new ThreadLocal[Future[Unit]]()

@@ -83,13 +84,15 @@ class FlintJobITSuite extends FlintSparkSuite with JobTest {

def createJobOperator(query: String, jobRunId: String): JobOperator = {
val streamingRunningCount = new AtomicInteger(0)
val statementRunningCount = new AtomicInteger(0)

/*
* Because we cannot test from FlintJob.main() for the reason below, we have to configure
* all Spark conf required by Flint code underlying manually.
*/
spark.conf.set(DATA_SOURCE_NAME.key, dataSourceName)
spark.conf.set(JOB_TYPE.key, FlintJobType.STREAMING)
spark.conf.set(REQUEST_INDEX.key, requestIndex)

val job = JobOperator(
appId,
@@ -100,7 +103,8 @@
dataSourceName,
resultIndex,
FlintJobType.STREAMING,
streamingRunningCount)
streamingRunningCount,
statementRunningCount)
job.terminateJVM = false
job
}
@@ -8,12 +8,15 @@ package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.{ExecutionContext, ExecutionContextExecutor}

import org.opensearch.flint.core.logging.CustomLogging
import org.opensearch.flint.core.metrics.MetricConstants
import org.opensearch.flint.core.metrics.MetricsUtil.registerGauge

import org.apache.spark.internal.Logging
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.util.ThreadUtils

/**
* Spark SQL Application entrypoint
@@ -26,52 +29,71 @@ import org.apache.spark.sql.flint.config.FlintSparkConf
* write sql query result to given opensearch index
*/
object FlintJob extends Logging with FlintJobExecutor {
private val streamingRunningCount = new AtomicInteger(0)
private val statementRunningCount = new AtomicInteger(0)

def main(args: Array[String]): Unit = {
val (queryOption, resultIndexOption) = parseArgs(args)

val conf = createSparkConf()
val jobType = conf.get("spark.flint.job.type", FlintJobType.BATCH)
CustomLogging.logInfo(s"""Job type is: ${jobType}""")
conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val sparkSession = createSparkSession(conf)
val applicationId =
environmentProvider.getEnvVar("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", "unknown")
val jobId = environmentProvider.getEnvVar("SERVERLESS_EMR_JOB_ID", "unknown")
val isWarmpoolEnabled = conf.get(FlintSparkConf.WARMPOOL_ENABLED.key, "false").toBoolean
logInfo(s"isWarmpoolEnabled: ${isWarmpoolEnabled}")

if (!isWarmpoolEnabled) {
val jobType = sparkSession.conf.get("spark.flint.job.type", FlintJobType.BATCH)
CustomLogging.logInfo(s"""Job type is: ${jobType}""")
sparkSession.conf.set(FlintSparkConf.JOB_TYPE.key, jobType)

val dataSource = conf.get("spark.flint.datasource.name", "")
val query = queryOption.getOrElse(unescapeQuery(conf.get(FlintSparkConf.QUERY.key, "")))
if (query.isEmpty) {
logAndThrow(s"Query undefined for the ${jobType} job.")
}
val queryId = conf.get(FlintSparkConf.QUERY_ID.key, "")

val streamingRunningCount = new AtomicInteger(0)
val jobOperator =
JobOperator(
applicationId,
jobId,
createSparkSession(conf),
query,
queryId,
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
if (resultIndexOption.isEmpty) {
logAndThrow("resultIndex is not set")
}
// https://github.com/opensearch-project/opensearch-spark/issues/138
/*
* To execute queries such as `CREATE SKIPPING INDEX ON my_glue1.default.http_logs_plain (`@timestamp` VALUE_SET) WITH (auto_refresh = true)`,
* it's necessary to set `spark.sql.defaultCatalog=my_glue1`. This is because AWS Glue uses a single database (default) and table (http_logs_plain),
* and we need to configure Spark to recognize `my_glue1` as a reference to AWS Glue's database and table.
* By doing this, we effectively map `my_glue1` to AWS Glue, allowing Spark to resolve the database and table names correctly.
* Without this setup, Spark would not recognize names in the format `my_glue1.default`.
*/
conf.set("spark.sql.defaultCatalog", dataSource)
configDYNMaxExecutors(conf, jobType)

val jobOperator =
JobOperator(
applicationId,
jobId,
sparkSession,
query,
queryId,
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount,
statementRunningCount,
Map.empty)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
} else {
// Fetch and execute queries in warm pool mode
val warmpoolJob =
WarmpoolJob(
applicationId,
jobId,
sparkSession,
streamingRunningCount,
statementRunningCount)
warmpoolJob.start()
}
}
}
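Aside (not part of this diff): because WarmpoolJob is constructed directly in the branch above, it can also be driven by hand, much as FlintJobITSuite builds a JobOperator; a minimal sketch with placeholder IDs and an existing SparkSession:

import java.util.concurrent.atomic.AtomicInteger

val streamingRunningCount = new AtomicInteger(0)
val statementRunningCount = new AtomicInteger(0)
val warmpoolJob =
  WarmpoolJob(
    "testAppId",            // placeholder application id
    "testJobId",            // placeholder job id
    spark,                  // an existing SparkSession
    streamingRunningCount,
    statementRunningCount)
warmpoolJob.start()         // fetches and executes queries of all types in warm pool mode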
@@ -6,6 +6,7 @@
package org.apache.spark.sql

import java.util.Locale
import java.util.concurrent.ThreadPoolExecutor

import com.amazonaws.services.glue.model.{AccessDeniedException, AWSGlueException}
import com.amazonaws.services.s3.model.AmazonS3Exception
@@ -20,6 +21,7 @@ import play.api.libs.json._

import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.FlintREPL.instantiate
import org.apache.spark.sql.SparkConfConstants.{DEFAULT_SQL_EXTENSIONS, SQL_EXTENSIONS_KEY}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.exception.UnrecoverableException
@@ -566,4 +568,31 @@ trait FlintJobExecutor {
}
}
}

def instantiateQueryResultWriter(
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

def instantiateStatementExecutionManager(
commandContext: CommandContext): StatementExecutionManager = {
import commandContext._
instantiate(
new StatementExecutionManagerImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""),
spark,
sessionId)
}

def instantiateSessionManager(
spark: SparkSession,
resultIndexOption: Option[String]): SessionManager = {
instantiate(
new SessionManagerImpl(spark, resultIndexOption),
spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""),
resultIndexOption.getOrElse(""))
}
}
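The three helpers above are moved out of FlintREPL so both the REPL and the warm pool path can share them. As a rough usage sketch from code that mixes in FlintJobExecutor (not part of this diff; spark and commandContext are assumed to be in scope, and the class name is a hypothetical placeholder):

// With spark.flint.job.customQueryResultWriter unset (empty string), the helper falls
// back to QueryResultWriterImpl(commandContext); otherwise the named class is loaded
// via the instantiate helper reused from FlintREPL.
spark.conf.set(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, "com.example.MyQueryResultWriter") // hypothetical class
val writer: QueryResultWriter = instantiateQueryResultWriter(spark, commandContext)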
@@ -106,7 +106,8 @@ object FlintREPL extends Logging with FlintJobExecutor {
dataSource,
resultIndexOption.get,
jobType,
streamingRunningCount)
streamingRunningCount,
statementRunningCount)
registerGauge(MetricConstants.STREAMING_RUNNING_METRIC, streamingRunningCount)
jobOperator.start()
} else {
@@ -1021,33 +1022,6 @@ object FlintREPL extends Logging with FlintJobExecutor {
}
}

private def instantiateSessionManager(
spark: SparkSession,
resultIndexOption: Option[String]): SessionManager = {
instantiate(
new SessionManagerImpl(spark, resultIndexOption),
spark.conf.get(FlintSparkConf.CUSTOM_SESSION_MANAGER.key, ""),
resultIndexOption.getOrElse(""))
}

private def instantiateStatementExecutionManager(
commandContext: CommandContext): StatementExecutionManager = {
import commandContext._
instantiate(
new StatementExecutionManagerImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_STATEMENT_MANAGER.key, ""),
spark,
sessionId)
}

private def instantiateQueryResultWriter(
spark: SparkSession,
commandContext: CommandContext): QueryResultWriter = {
instantiate(
new QueryResultWriterImpl(commandContext),
spark.conf.get(FlintSparkConf.CUSTOM_QUERY_RESULT_WRITER.key, ""))
}

private def recordSessionSuccess(sessionTimerContext: Timer.Context): Unit = {
logInfo("Session Success")
stopTimer(sessionTimerContext)