Skip to content

Commit

Permalink
[SPARK-50666][SQL] Support hint for reading in JDBC data source
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR aims to add a `hint` option for the JDBC data source. This option is used to specify the hint for reading. It applies only if the underlying DBMS supports the hint feature. Currently, this option is only supported by MySQLDialect, OracleDialect and DatabricksDialect.

### Why are the changes needed?

It's useful for performance tuning when reading from DBMS.

### Does this PR introduce _any_ user-facing change?

Yes. It adds a new `hint` read option to the JDBC data source.

### How was this patch tested?

Passed GA and added a new test case.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49564 from wayneguow/jdbc_hint.

Authored-by: Wei Guo <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
  • Loading branch information
wayneguow authored and dongjoon-hyun committed Jan 27, 2025
1 parent b5deb8d commit fef1b23
Show file tree
Hide file tree
Showing 14 changed files with 138 additions and 8 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -1701,6 +1701,12 @@
],
"sqlState" : "42822"
},
"HINT_UNSUPPORTED_FOR_JDBC_DIALECT" : {
"message" : [
"The option `hint` is not supported for <jdbcDialect> in JDBC data source. Supported dialects are `MySQLDialect`, `OracleDialect` and `DatabricksDialect`."
],
"sqlState" : "42822"
},
"HLL_INVALID_INPUT_SKETCH_BUFFER" : {
"message" : [
"Invalid call to <function>; only valid HLL sketch buffers are supported as inputs (such as those produced by the `hll_sketch_agg` function)."
Expand Down
8 changes: 8 additions & 0 deletions docs/sql-data-sources-jdbc.md
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,14 @@ logging into the data sources.
</td>
<td>read</td>
</tr>
<tr>
<td><code>hint</code></td>
<td>(none)</td>
<td>
This option is used to specify the hint for reading. The supported hint format is a variant of C-style comments: it needs to start with <code>/*+ </code> and end with <code> */</code>. Currently, this option is only supported in MySQLDialect, OracleDialect and DatabricksDialect.
</td>
<td>read</td>
</tr>
</table>

Note that kerberos authentication with keytab is not always supported by the JDBC driver.<br>
Expand Down
21 changes: 21 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,9 @@
<db2.jcc.version>11.5.9.0</db2.jcc.version>
<mssql.jdbc.version>12.8.1.jre11</mssql.jdbc.version>
<ojdbc17.version>23.6.0.24.10</ojdbc17.version>
<databricks.jdbc.version>2.7.1</databricks.jdbc.version>
<snowflake.jdbc.version>3.21.0</snowflake.jdbc.version>
<terajdbc.version>20.00.00.39</terajdbc.version>
<!-- Used for SBT build to retrieve the Spark version -->
<spark.version>${project.version}</spark.version>
</properties>
Expand Down Expand Up @@ -1350,6 +1353,24 @@
<version>${ojdbc17.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.databricks</groupId>
<artifactId>databricks-jdbc</artifactId>
<version>${databricks.jdbc.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>net.snowflake</groupId>
<artifactId>snowflake-jdbc</artifactId>
<version>${snowflake.jdbc.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.teradata.jdbc</groupId>
<artifactId>terajdbc</artifactId>
<version>${terajdbc.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.curator</groupId>
<artifactId>curator-recipes</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
messageParameters = Map("content" -> content))
}

/**
 * Creates the error reported when the JDBC `hint` option is used with a dialect that does not
 * support it.
 *
 * @param jdbcDialect name of the offending dialect, substituted for `<jdbcDialect>` in the
 *                    `HINT_UNSUPPORTED_FOR_JDBC_DIALECT` message
 * @return the exception for the caller to throw
 */
def hintUnsupportedForJdbcDialectError(jdbcDialect: String): SparkIllegalArgumentException = {
  val params = Map("jdbcDialect" -> jdbcDialect)
  new SparkIllegalArgumentException(
    errorClass = "HINT_UNSUPPORTED_FOR_JDBC_DIALECT",
    messageParameters = params)
}

def unsupportedArrayElementTypeBasedOnBinaryError(dt: DataType): SparkIllegalArgumentException = {
new SparkIllegalArgumentException(
errorClass = "_LEGACY_ERROR_TEMP_2084",
Expand Down
15 changes: 15 additions & 0 deletions sql/core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,21 @@
<artifactId>derbytools</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.databricks</groupId>
<artifactId>databricks-jdbc</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>net.snowflake</groupId>
<artifactId>snowflake-jdbc</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.teradata.jdbc</groupId>
<artifactId>terajdbc</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.parquet</groupId>
<artifactId>parquet-avro</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,15 @@ class JDBCOptions(
.map(_.toBoolean)
.getOrElse(SQLConf.get.timestampType == TimestampNTZType)

// Read hint taken from the `hint` option. When the option is present it must start with
// `/*+ ` and end with ` */` (the `(?s)` flag lets the hint body span several lines), and the
// stored value gets a trailing space appended. When the option is absent this is the empty
// string.
val hint = parameters.get(JDBC_HINT_STRING) match {
  case Some(value) =>
    require(value.matches("(?s)^/\\*\\+ .* \\*/$"),
      s"Invalid value `$value` for option `$JDBC_HINT_STRING`." +
        s" It should start with `/*+ ` and end with ` */`.")
    s"$value "
  case None =>
    ""
}

override def hashCode: Int = this.parameters.hashCode()

override def equals(other: Any): Boolean = other match {
Expand Down Expand Up @@ -321,4 +330,5 @@ object JDBCOptions {
val JDBC_CONNECTION_PROVIDER = newOption("connectionProvider")
val JDBC_PREPARE_QUERY = newOption("prepareQuery")
val JDBC_PREFER_TIMESTAMP_NTZ = newOption("preferTimestampNTZ")
val JDBC_HINT_STRING = newOption("hint")
}
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ private case class DB2Dialect() extends JdbcDialect with SQLConfHelper with NoLe
val offsetClause = dialect.getOffsetClause(offset)

options.prepareQuery +
s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
s"SELECT $hintClause$columnList FROM ${options.tableOrQuery} $tableSampleClause" +
s" $whereClause $groupByClause $orderByClause $offsetClause $limitClause"
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ private case class DatabricksDialect() extends JdbcDialect with NoLegacyJDBCErro
s"TABLESAMPLE (${(sample.upperBound - sample.lowerBound) * 100}) REPEATABLE (${sample.seed})"
}

// Opts this dialect in to the `hint` read option; JdbcSQLQueryBuilder raises
// HINT_UNSUPPORTED_FOR_JDBC_DIALECT when a hint is set and this is false.
override def supportsHint: Boolean = true

// Override listSchemas to run "show schemas" as a PreparedStatement instead of
// invoking getMetaData.getSchemas as it may not work correctly in older versions of the driver.
override def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,8 @@ abstract class JdbcDialect extends Serializable with Logging {
def getTableSample(sample: TableSampleInfo): String =
throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3183")

/**
 * Whether this dialect accepts the `hint` read option. Defaults to `false`; dialects that
 * support hints override this to return `true`. JdbcSQLQueryBuilder checks this flag and
 * raises `HINT_UNSUPPORTED_FOR_JDBC_DIALECT` when a hint is set for a dialect that returns
 * `false`.
 */
def supportsHint: Boolean = false

/**
* Return the DB-specific quoted and fully qualified table name
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.jdbc

import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition}
import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo

Expand Down Expand Up @@ -67,6 +68,18 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) {
*/
protected var tableSampleClause: String = ""

/**
 * The query hint to embed in the generated SELECT statement, taken from the `hint` option.
 * It is the empty string when no hint was given. Fails with
 * `HINT_UNSUPPORTED_FOR_JDBC_DIALECT` when a hint was given but the dialect does not
 * support hints.
 */
protected val hintClause: String = {
  val hintRequested = options.hint != ""
  // The dialect is only consulted when a hint was actually supplied.
  if (hintRequested && !dialect.supportsHint) {
    throw QueryExecutionErrors.hintUnsupportedForJdbcDialectError(
      dialect.getClass.getSimpleName)
  }
  options.hint
}

/**
* The columns names that following dialect's SQL syntax.
* e.g. The column name is the raw name or quoted name.
Expand Down Expand Up @@ -161,7 +174,7 @@ class JdbcSQLQueryBuilder(dialect: JdbcDialect, options: JDBCOptions) {
val offsetClause = dialect.getOffsetClause(offset)

options.prepareQuery +
s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
s"SELECT $hintClause$columnList FROM ${options.tableOrQuery} $tableSampleClause" +
s" $whereClause $groupByClause $orderByClause $limitClause $offsetClause"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ private case class MsSqlServerDialect() extends JdbcDialect with NoLegacyJDBCErr
val limitClause = dialect.getLimitClause(limit)

options.prepareQuery +
s"SELECT $limitClause $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
s" $whereClause $groupByClause $orderByClause"
s"SELECT $hintClause$limitClause $columnList FROM ${options.tableOrQuery}" +
s" $tableSampleClause $whereClause $groupByClause $orderByClause"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No
}

options.prepareQuery +
s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
s"SELECT $hintClause$columnList FROM ${options.tableOrQuery} $tableSampleClause" +
s" $whereClause $groupByClause $orderByClause $limitOrOffsetStmt"
}
}
Expand All @@ -417,4 +417,6 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No
override def supportsLimit: Boolean = true

override def supportsOffset: Boolean = true

// Opts this dialect in to the `hint` read option; JdbcSQLQueryBuilder raises
// HINT_UNSUPPORTED_FOR_JDBC_DIALECT when a hint is set and this is false.
override def supportsHint: Boolean = true
}
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N
extends JdbcSQLQueryBuilder(dialect, options) {

override def build(): String = {
val selectStmt = s"SELECT $columnList FROM ${options.tableOrQuery} $tableSampleClause" +
s" $whereClause $groupByClause $orderByClause"
val selectStmt = s"SELECT $hintClause$columnList FROM ${options.tableOrQuery}" +
s" $tableSampleClause $whereClause $groupByClause $orderByClause"
val finalSelectStmt = if (limit > 0) {
if (offset > 0) {
// Because the rownum is calculated when the value is returned,
Expand Down Expand Up @@ -255,6 +255,8 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N

override def supportsOffset: Boolean = true

// Opts this dialect in to the `hint` read option; JdbcSQLQueryBuilder raises
// HINT_UNSUPPORTED_FOR_JDBC_DIALECT when a hint is set and this is false.
override def supportsHint: Boolean = true

override def classifyException(
e: Throwable,
errorClass: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import scala.util.Random
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito._

import org.apache.spark.{SparkException, SparkSQLException}
import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkSQLException}
import org.apache.spark.sql.{AnalysisException, DataFrame, Observation, QueryTest, Row}
import org.apache.spark.sql.catalyst.{analysis, TableIdentifier}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
Expand Down Expand Up @@ -2206,4 +2206,47 @@ class JDBCSuite extends QueryTest with SharedSparkSession {
JdbcUtils.schemaString(dialect, schema, caseSensitive = false, Some("b boolean"))
assert(schemaStr === """"b" NUMBER(1) """)
}

test("SPARK-50666: reading hint test") {
  // Format check: a hint missing the `/*+ ` prefix, missing the ` */` suffix, or empty is
  // rejected as soon as the options are constructed.
  val malformedHints = Seq("INDEX(test idx1) */", "/*+ INDEX(test idx1)", "")
  malformedHints.foreach { hint =>
    val message = intercept[IllegalArgumentException] {
      new JDBCOptions(Map("url" -> url, "dbtable" -> "test", "hint" -> hint))
    }.getMessage
    assert(message.contains(s"Invalid value `$hint` for option `hint`." +
      s" It should start with `/*+ ` and end with ` */`."))
  }

  // Dialect support check: the same well-formed hint is used for every URL below.
  val baseParameters = CaseInsensitiveMap(
    Map("dbtable" -> "test", "hint" -> "/*+ INDEX(test idx1) */"))

  // URLs resolving to dialects that accept the hint: it is emitted right after SELECT.
  val supportedUrls = Seq(
    "jdbc:oracle:thin:@//host:port",
    "jdbc:mysql://host:port",
    "jdbc:databricks://host:port")
  supportedUrls.foreach { jdbcUrl =>
    val options = new JDBCOptions(baseParameters + ("url" -> jdbcUrl))
    val dialect = JdbcDialects.get(jdbcUrl)
    val query = dialect.getJdbcSQLQueryBuilder(options)
      .withColumns(Array("a", "b"))
      .build()
    assert(query.trim() == "SELECT /*+ INDEX(test idx1) */ a,b FROM test")
  }

  // URLs resolving to dialects that reject the hint: building the query must fail with
  // HINT_UNSUPPORTED_FOR_JDBC_DIALECT naming the dialect class.
  val unsupportedUrls = Seq(
    "jdbc:db2://host:port",
    "jdbc:derby:memory",
    "jdbc:h2://host:port",
    "jdbc:sqlserver://host:port",
    "jdbc:postgresql://host:5432/postgres",
    "jdbc:snowflake://host:443?account=test",
    "jdbc:teradata://host:port")
  unsupportedUrls.foreach { jdbcUrl =>
    val options = new JDBCOptions(baseParameters + ("url" -> jdbcUrl))
    val dialect = JdbcDialects.get(jdbcUrl)
    val expectedParameters = Map("jdbcDialect" -> dialect.getClass.getSimpleName)
    checkError(
      exception = intercept[SparkIllegalArgumentException] {
        dialect.getJdbcSQLQueryBuilder(options)
          .withColumns(Array("a", "b"))
          .build().trim()
      },
      condition = "HINT_UNSUPPORTED_FOR_JDBC_DIALECT",
      parameters = expectedParameters)
  }
}
}

0 comments on commit fef1b23

Please sign in to comment.