
Commit edfd357

pan3793 authored and LuciferYang committed
[SPARK-54108][CONNECT] Revise execute* methods of SparkConnectStatement
### What changes were proposed in this pull request?

This PR revises the following three `execute*` methods and the additional `getUpdateCount` method of `SparkConnectStatement`, which are defined in `java.sql.Statement`:

```java
/**
 * Executes the given SQL statement, which returns a single
 * {@code ResultSet} object.
 * <p>
 * <strong>Note:</strong> This method cannot be called on a
 * {@code PreparedStatement} or {@code CallableStatement}.
 * @param sql an SQL statement to be sent to the database, typically a
 *        static SQL {@code SELECT} statement
 * @return a {@code ResultSet} object that contains the data produced
 *         by the given query; never {@code null}
 * @throws SQLException if a database access error occurs,
 *         this method is called on a closed {@code Statement}, the given
 *         SQL statement produces anything other than a single
 *         {@code ResultSet} object, the method is called on a
 *         {@code PreparedStatement} or {@code CallableStatement}
 * @throws SQLTimeoutException when the driver has determined that the
 *         timeout value that was specified by the {@code setQueryTimeout}
 *         method has been exceeded and has at least attempted to cancel
 *         the currently running {@code Statement}
 */
ResultSet executeQuery(String sql) throws SQLException;

/**
 * Executes the given SQL statement, which may be an {@code INSERT},
 * {@code UPDATE}, or {@code DELETE} statement or an
 * SQL statement that returns nothing, such as an SQL DDL statement.
 * <p>
 * <strong>Note:</strong> This method cannot be called on a
 * {@code PreparedStatement} or {@code CallableStatement}.
 * @param sql an SQL Data Manipulation Language (DML) statement, such as
 *        {@code INSERT}, {@code UPDATE} or {@code DELETE}; or an SQL
 *        statement that returns nothing, such as a DDL statement.
 * @return either (1) the row count for SQL Data Manipulation Language (DML) statements
 *         or (2) 0 for SQL statements that return nothing
 * @throws SQLException if a database access error occurs,
 *         this method is called on a closed {@code Statement}, the given
 *         SQL statement produces a {@code ResultSet} object, the method is called on a
 *         {@code PreparedStatement} or {@code CallableStatement}
 * @throws SQLTimeoutException when the driver has determined that the
 *         timeout value that was specified by the {@code setQueryTimeout}
 *         method has been exceeded and has at least attempted to cancel
 *         the currently running {@code Statement}
 */
int executeUpdate(String sql) throws SQLException;

/**
 * Executes the given SQL statement, which may return multiple results.
 * In some (uncommon) situations, a single SQL statement may return
 * multiple result sets and/or update counts. Normally you can ignore
 * this unless you are (1) executing a stored procedure that you know may
 * return multiple results or (2) you are dynamically executing an
 * unknown SQL string.
 * <p>
 * The {@code execute} method executes an SQL statement and indicates the
 * form of the first result. You must then use the methods
 * {@code getResultSet} or {@code getUpdateCount}
 * to retrieve the result, and {@code getMoreResults} to
 * move to any subsequent result(s).
 * <p>
 * <strong>Note:</strong> This method cannot be called on a
 * {@code PreparedStatement} or {@code CallableStatement}.
 * @param sql any SQL statement
 * @return {@code true} if the first result is a {@code ResultSet}
 *         object; {@code false} if it is an update count or there are
 *         no results
 * @throws SQLException if a database access error occurs,
 *         this method is called on a closed {@code Statement},
 *         the method is called on a
 *         {@code PreparedStatement} or {@code CallableStatement}
 * @throws SQLTimeoutException when the driver has determined that the
 *         timeout value that was specified by the {@code setQueryTimeout}
 *         method has been exceeded and has at least attempted to cancel
 *         the currently running {@code Statement}
 * @see #getResultSet
 * @see #getUpdateCount
 * @see #getMoreResults
 */
boolean execute(String sql) throws SQLException;

/**
 * Retrieves the current result as an update count;
 * if the result is a {@code ResultSet} object or there are no more results, -1
 * is returned. This method should be called only once per result.
 * @return the current result as an update count; -1 if the current result is a
 *         {@code ResultSet} object or there are no more results
 * @throws SQLException if a database access error occurs or
 *         this method is called on a closed {@code Statement}
 * @see #execute
 */
int getUpdateCount() throws SQLException;
```

### Why are the changes needed?

Make the implementation respect the JDBC API specification.

### Does this PR introduce _any_ user-facing change?

No, the Connect JDBC Driver is an unreleased feature.

### How was this patch tested?

New UTs are added.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #52810 from pan3793/SPARK-54108.

Authored-by: Cheng Pan <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
1 parent 2be1eb7 commit edfd357
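For readers less familiar with these `java.sql.Statement` semantics, the following is a minimal client-side sketch of the contract the revised methods follow. It assumes the Connect JDBC driver is on the classpath and a Spark Connect server is reachable; the `jdbc:sc://` URL scheme is taken from the test suite in this commit, while the port `15002` and the table name `t` are illustrative assumptions.

```scala
import java.sql.{DriverManager, SQLException}

object ExecuteSemanticsExample {
  def main(args: Array[String]): Unit = {
    // Host, port, and table name below are assumptions for illustration only.
    val conn = DriverManager.getConnection("jdbc:sc://localhost:15002")
    val stmt = conn.createStatement()
    try {
      // DDL/DML: execute() returns false, getUpdateCount() returns 0
      // (affected-row counts are not supported yet), getResultSet() is null.
      assert(!stmt.execute("CREATE TABLE t (id INT) USING parquet"))
      assert(stmt.getUpdateCount == 0)
      assert(stmt.getResultSet == null)

      // Query: execute() returns true, the result is exposed via
      // getResultSet(), and getUpdateCount() returns -1.
      assert(stmt.execute("SELECT id FROM t"))
      assert(stmt.getUpdateCount == -1)
      stmt.getResultSet.close()

      // executeQuery() rejects statements that do not produce a ResultSet,
      // and executeUpdate() rejects statements that do.
      try stmt.executeQuery("INSERT INTO t VALUES (1)")
      catch { case e: SQLException => println(e.getMessage) }
      try stmt.executeUpdate("SELECT id FROM t")
      catch { case e: SQLException => println(e.getMessage) }
    } finally {
      stmt.close()
      conn.close()
    }
  }
}
```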

2 files changed: +127 −20 lines

sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala

Lines changed: 47 additions & 20 deletions

```diff
@@ -19,6 +19,8 @@ package org.apache.spark.sql.connect.client.jdbc
 
 import java.sql.{Array => _, _}
 
+import org.apache.spark.sql.connect.client.SparkResult
+
 class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
 
   private var operationId: String = _
@@ -49,33 +51,51 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
   }
 
   override def executeQuery(sql: String): ResultSet = {
-    checkOpen()
-
-    val df = conn.spark.sql(sql)
-    val sparkResult = df.collectResult()
-    operationId = sparkResult.operationId
-    resultSet = new SparkConnectResultSet(sparkResult, this)
-    resultSet
+    val hasResultSet = execute(sql)
+    if (hasResultSet) {
+      assert(resultSet != null)
+      resultSet
+    } else {
+      throw new SQLException("The query does not produce a ResultSet.")
+    }
   }
 
   override def executeUpdate(sql: String): Int = {
-    checkOpen()
-
-    val df = conn.spark.sql(sql)
-    val sparkResult = df.collectResult()
-    operationId = sparkResult.operationId
-    resultSet = null
+    val hasResultSet = execute(sql)
+    if (hasResultSet) {
+      // users are not expected to access the result set in this case,
+      // we must close it to avoid a memory leak.
+      resultSet.close()
+      throw new SQLException("The query produces a ResultSet.")
+    } else {
+      assert(resultSet == null)
+      getUpdateCount
+    }
+  }
 
-    // always return 0 because affected rows is not supported yet
-    0
+  private def hasResultSet(sparkResult: SparkResult[_]): Boolean = {
+    // suppose this works in most cases
+    sparkResult.schema.length > 0
   }
 
   override def execute(sql: String): Boolean = {
     checkOpen()
 
-    // always perform executeQuery and return a ResultSet
-    executeQuery(sql)
-    true
+    // stmt can be reused to execute more than one query,
+    // reset before executing a new query
+    operationId = null
+    resultSet = null
+
+    val df = conn.spark.sql(sql)
+    val sparkResult = df.collectResult()
+    operationId = sparkResult.operationId
+    if (hasResultSet(sparkResult)) {
+      resultSet = new SparkConnectResultSet(sparkResult, this)
+      true
+    } else {
+      sparkResult.close()
+      false
+    }
   }
 
   override def getResultSet: ResultSet = {
@@ -123,8 +143,15 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
   override def setCursorName(name: String): Unit =
     throw new SQLFeatureNotSupportedException
 
-  override def getUpdateCount: Int =
-    throw new SQLFeatureNotSupportedException
+  override def getUpdateCount: Int = {
+    checkOpen()
+
+    if (resultSet != null) {
+      -1
+    } else {
+      0 // always return 0 because affected rows is not supported yet
+    }
+  }
 
   override def getMoreResults: Boolean =
     throw new SQLFeatureNotSupportedException
```
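The change above boils down to a delegation pattern: `executeQuery` and `executeUpdate` both funnel through `execute`, which resets the statement state, runs the SQL, and records whether a result set was produced. Below is a simplified, standalone sketch of that pattern, not the actual driver classes: the `run` function and `Seq[String]` result type are stand-ins for the real `SparkResult` handling.

```scala
import java.sql.SQLException

// Standalone sketch of the delegation pattern; `run` models statement
// execution: Some(rows) for queries, None for DDL/DML statements.
class DelegatingStatement(run: String => Option[Seq[String]]) {
  private var currentResult: Option[Seq[String]] = None

  def execute(sql: String): Boolean = {
    currentResult = None        // reset so the statement can be reused
    currentResult = run(sql)
    currentResult.isDefined     // true iff the statement produced a result set
  }

  def executeQuery(sql: String): Seq[String] =
    if (execute(sql)) currentResult.get
    else throw new SQLException("The query does not produce a ResultSet.")

  def executeUpdate(sql: String): Int =
    if (execute(sql)) throw new SQLException("The query produces a ResultSet.")
    else getUpdateCount

  // -1 while a result set is the current result, 0 otherwise
  // (affected-row counts are not supported).
  def getUpdateCount: Int = if (currentResult.isDefined) -1 else 0
}
```

For example, `new DelegatingStatement(sql => if (sql.trim.toUpperCase.startsWith("SELECT")) Some(Seq("1")) else None)` exercises both branches.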
SparkConnectStatementSuite.scala (new test file)

Lines changed: 80 additions & 0 deletions

```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.connect.client.jdbc

import java.sql.{Array => _, _}

import scala.util.Using

import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper
import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper}

class SparkConnectStatementSuite extends ConnectFunSuite with RemoteSparkSession
  with JdbcHelper with SQLHelper {

  override def jdbcUrl: String = s"jdbc:sc://localhost:$serverPort"

  test("returned result set and update count of execute* methods") {
    withTable("t1", "t2", "t3") {
      withStatement { stmt =>
        // CREATE TABLE
        assert(!stmt.execute("CREATE TABLE t1 (id INT) USING Parquet"))
        assert(stmt.getUpdateCount === 0)
        assert(stmt.getResultSet === null)

        var se = intercept[SQLException] {
          stmt.executeQuery("CREATE TABLE t2 (id INT) USING Parquet")
        }
        assert(se.getMessage === "The query does not produce a ResultSet.")

        assert(stmt.executeUpdate("CREATE TABLE t3 (id INT) USING Parquet") === 0)
        assert(stmt.getResultSet === null)

        // INSERT INTO
        assert(!stmt.execute("INSERT INTO t1 VALUES (1)"))
        assert(stmt.getUpdateCount === 0)
        assert(stmt.getResultSet === null)

        se = intercept[SQLException] {
          stmt.executeQuery("INSERT INTO t1 VALUES (1)")
        }
        assert(se.getMessage === "The query does not produce a ResultSet.")

        assert(stmt.executeUpdate("INSERT INTO t1 VALUES (1)") === 0)
        assert(stmt.getResultSet === null)

        // SELECT
        assert(stmt.execute("SELECT id FROM t1"))
        assert(stmt.getUpdateCount === -1)
        Using.resource(stmt.getResultSet) { rs =>
          assert(rs !== null)
        }

        Using.resource(stmt.executeQuery("SELECT id FROM t1")) { rs =>
          assert(stmt.getUpdateCount === -1)
          assert(rs !== null)
        }

        se = intercept[SQLException] {
          stmt.executeUpdate("SELECT id FROM t1")
        }
        assert(se.getMessage === "The query produces a ResultSet.")
      }
    }
  }
}
```
