diff --git a/externals/kyuubi-spark-sql-engine/pom.xml b/externals/kyuubi-spark-sql-engine/pom.xml
index e0621c7fd28..dc98da1b725 100644
--- a/externals/kyuubi-spark-sql-engine/pom.xml
+++ b/externals/kyuubi-spark-sql-engine/pom.xml
@@ -115,6 +115,12 @@
             <scope>test</scope>
         </dependency>
 
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-hive_${scala.binary.version}</artifactId>
+            <scope>provided</scope>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.kyuubi</groupId>
             <artifactId>${hive.jdbc.artifact}</artifactId>
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteScala.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteScala.scala
index 092e6e8241c..5e5a74c0b91 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteScala.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/kyuubi/engine/spark/operation/ExecuteScala.scala
@@ -25,7 +25,9 @@ import scala.tools.nsc.interpreter.Results.{Error, Incomplete, Success}
 
 import org.apache.hadoop.fs.Path
 import org.apache.spark.SparkFiles
+import org.apache.spark.kyuubi.SparkJobArtifactHelper
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.hive.HiveClientHelper
 import org.apache.spark.sql.types.StructType
 
 import org.apache.kyuubi.KyuubiSQLException
@@ -90,13 +92,14 @@ class ExecuteScala(
       warn(s"Clearing legacy output from last interpreting:\n $legacyOutput")
     }
     val replUrls = repl.classLoader.getParent.asInstanceOf[URLClassLoader].getURLs
-    spark.sharedState.jarClassLoader.getURLs.filterNot(replUrls.contains).foreach { jar =>
+    val root = new File(SparkFiles.getRootDirectory(), session.handle.identifier.toString)
+    HiveClientHelper.getLoadedClasses(spark).filterNot(replUrls.contains).foreach { jar =>
       try {
         if ("file".equals(jar.toURI.getScheme)) {
           repl.addUrlsToClassPath(jar)
         } else {
           spark.sparkContext.addFile(jar.toString)
-          val localJarFile = new File(SparkFiles.get(new Path(jar.toURI.getPath).getName))
+          val localJarFile = new File(root, new Path(jar.toURI.getPath).getName)
           val localJarUrl = localJarFile.toURI.toURL
           if (!replUrls.contains(localJarUrl)) {
             repl.addUrlsToClassPath(localJarUrl)
@@ -140,7 +143,9 @@ class ExecuteScala(
       val asyncOperation = new Runnable {
         override def run(): Unit = {
           OperationLog.setCurrentOperationLog(operationLog)
-          executeScala()
+          withSessionArtifactState {
+            executeScala()
+          }
         }
       }
 
@@ -157,7 +162,15 @@ class ExecuteScala(
         throw ke
       }
     } else {
-      executeScala()
+      withSessionArtifactState {
+        executeScala()
+      }
+    }
+  }
+
+  private def withSessionArtifactState(f: => Unit): Unit = {
+    SparkJobArtifactHelper.withActiveJobArtifactState(session.handle) {
+      f
     }
   }
 }
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SparkJobArtifactHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SparkJobArtifactHelper.scala
new file mode 100644
index 00000000000..a862dbc6475
--- /dev/null
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/kyuubi/SparkJobArtifactHelper.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.kyuubi
+
+import org.apache.spark.{JobArtifactSet, JobArtifactState}
+
+import org.apache.kyuubi.session.SessionHandle
+
+object SparkJobArtifactHelper {
+
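+  /**
+   * Runs `f` with a session-scoped [[JobArtifactState]] active, so that jars and files
+   * added to the SparkContext inside the block are tracked per session rather than
+   * globally (relies on the `JobArtifactSet` API available since Spark 3.5).
+   */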
+  def withActiveJobArtifactState(handle: SessionHandle)(f: => Unit): Unit = {
+    val state = JobArtifactState(handle.identifier.toString, None)
+    JobArtifactSet.withActiveJobArtifactState(state) {
+      f
+    }
+  }
+}
diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/hive/HiveClientHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/hive/HiveClientHelper.scala
new file mode 100644
index 00000000000..9e252cf3f4b
--- /dev/null
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/hive/HiveClientHelper.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.net.URL
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.internal.SessionResourceLoader
+import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
+
+object HiveClientHelper {
+
+  type HiveClientImpl = org.apache.spark.sql.hive.client.HiveClientImpl
+
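+  /**
+   * Returns the jar URLs tracked for the session's `ADD JAR` resources: from the Hive
+   * client's isolated classloader when the catalog implementation is `hive` (where
+   * `ADD JAR` registers jars), otherwise from the shared jar classloader.
+   */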
+  def getLoadedClasses(spark: SparkSession): Array[URL] = {
+    if (spark.conf.get(CATALOG_IMPLEMENTATION).equals("hive")) {
+      val loader = spark.sessionState.resourceLoader
+      getHiveLoadedClasses(loader)
+    } else {
+      spark.sharedState.jarClassLoader.getURLs
+    }
+  }
+
+  private def getHiveLoadedClasses(loader: SessionResourceLoader): Array[URL] = {
+    if (loader != null) {
+      // HiveSessionResourceLoader keeps its Hive client in a private field, so read it reflectively.
+      val field = classOf[HiveSessionResourceLoader].getDeclaredField("client")
+      field.setAccessible(true)
+      val client = field.get(loader).asInstanceOf[HiveClientImpl]
+      if (client != null) {
+        client.clientLoader.classLoader.getURLs
+      } else {
+        Array.empty
+      }
+    } else {
+      Array.empty
+    }
+  }
+}
diff --git a/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerUserSuite.scala b/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerUserSuite.scala
index de491e03f21..e98418e9505 100644
--- a/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerUserSuite.scala
+++ b/kyuubi-server/src/test/scala/org/apache/kyuubi/operation/KyuubiOperationPerUserSuite.scala
@@ -203,6 +203,17 @@ class KyuubiOperationPerUserSuite
   }
 
   test("scala NPE issue with hdfs jar") {
+    val dfsJarPath = prepareHdfsJar
+    withJdbcStatement() { statement =>
+      val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
+      statement.executeQuery(s"add jar $dfsJarPath")
+      val rs = kyuubiStatement.executeScala("println(test.utils.Math.add(1,2))")
+      rs.next()
+      assert(rs.getString(1) === "3")
+    }
+  }
+
+  private def prepareHdfsJar: Path = {
     val jarDir = Utils.createTempDir().toFile
     val udfCode =
       """
@@ -225,12 +236,46 @@ class KyuubiOperationPerUserSuite
     val localPath = new Path(jarFile.getAbsolutePath)
     val dfsJarPath = new Path(dfsJarDir, "test-function.jar")
     FileUtil.copy(localFs, localPath, dfs, dfsJarPath, false, false, hadoopConf)
-    withJdbcStatement() { statement =>
-      val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
-      statement.executeQuery(s"add jar $dfsJarPath")
-      val rs = kyuubiStatement.executeScala("println(test.utils.Math.add(1,2))")
-      rs.next()
-      assert(rs.getString(1) === "3")
+    dfsJarPath
+  }
+
+  test("support scala mode resource isolation") {
+    val dfsJarPath = prepareHdfsJar
+    withSessionConf()(
+      Map(
+        KyuubiConf.ENGINE_SHARE_LEVEL_SUBDOMAIN.key -> "resource_isolation",
+        "spark.sql.catalogImplementation" -> "hive"))(
+      Map.empty) {
+      @volatile var r1: String = null
+      @volatile var exception: Exception = null
+
+      new Thread {
+        override def run(): Unit = withJdbcStatement() { statement =>
+          val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
+          kyuubiStatement.executeQuery(s"add jar $dfsJarPath")
+          val rs = kyuubiStatement.executeScala("println(test.utils.Math.add(1,2))")
+          rs.next()
+          r1 = rs.getString(1)
+        }
+      }.start()
+
+      new Thread {
+        override def run(): Unit = withJdbcStatement() { statement =>
+          val kyuubiStatement = statement.asInstanceOf[KyuubiStatement]
+          exception = intercept[Exception] {
+            kyuubiStatement.executeScala("println(test.utils.Math.add(1,2))")
+          }
+        }
+      }.start()
+
+      eventually(timeout(120.seconds), interval(100.milliseconds)) {
+        assert(r1 != null && exception != null)
+      }
+
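+      // The jar was added under the first session's artifact state, so the second
+      // session's REPL cannot resolve `test.utils.Math` and fails to compile.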
+      assert(r1 === "3")
+      assert(exception.getMessage.contains("not found: value test"))
     }
   }
 