From f64da4728c1eac1fb4b777595cbd7a051556ea86 Mon Sep 17 00:00:00 2001 From: Mark Hamilton Date: Fri, 20 Oct 2023 17:09:41 -0400 Subject: [PATCH] fix: handle long run ids in adb tests (#2102) --- .../ml/nbtest/DatabricksCPUTests.scala | 2 +- .../ml/nbtest/DatabricksGPUTests.scala | 2 +- .../ml/nbtest/DatabricksUtilities.scala | 50 ++++++++++--------- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksCPUTests.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksCPUTests.scala index c82f40a6a8..7098114a33 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksCPUTests.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksCPUTests.scala @@ -11,7 +11,7 @@ import scala.language.existentials class DatabricksCPUTests extends DatabricksTestHelper { val clusterId: String = createClusterInPool(ClusterName, AdbRuntime, NumWorkers, PoolId, "[]") - val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper(clusterId, Libraries, CPUNotebooks) + val jobIdsToCancel: ListBuffer[Long] = databricksTestHelper(clusterId, Libraries, CPUNotebooks) protected override def afterAll(): Unit = { afterAllHelper(jobIdsToCancel, clusterId, ClusterName) diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala index be308c7af7..ccc3e58ce2 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksGPUTests.scala @@ -16,7 +16,7 @@ class DatabricksGPUTests extends DatabricksTestHelper { "src", "main", "python", "horovod_installation.sh").getCanonicalFile uploadFileToDBFS(horovodInstallationScript, "/FileStore/horovod-fix-commit/horovod_installation.sh") val clusterId: String = createClusterInPool(GPUClusterName, AdbGpuRuntime, 2, GpuPoolId, GPUInitScripts) - val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper( + val jobIdsToCancel: ListBuffer[Long] = databricksTestHelper( clusterId, GPULibraries, GPUNotebooks) protected override def afterAll(): Unit = { diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala index 43f5203324..5b1c24a65f 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/nbtest/DatabricksUtilities.scala @@ -39,7 +39,7 @@ object DatabricksUtilities { lazy val Token: String = sys.env.getOrElse("MML_ADB_TOKEN", Secrets.AdbToken) lazy val AuthValue: String = "Basic " + BaseEncoding.base64() .encode(("token:" + Token).getBytes("UTF-8")) - val BaseURL = s"https://$Region.azuredatabricks.net/api/2.0/" + lazy val PoolId: String = getPoolIdByName(PoolName) lazy val GpuPoolId: String = getPoolIdByName(GpuPoolName) lazy val ClusterName = s"mmlspark-build-${LocalDateTime.now()}" @@ -67,6 +67,8 @@ object DatabricksUtilities { "interpret-community" ) + def baseURL(apiVersion: String): String = s"https://$Region.azuredatabricks.net/api/$apiVersion/" + val Libraries: String = ( List(Map("maven" -> Map("coordinates" -> PackageMavenCoordinate, "repo" -> PackageRepository))) ++ PipPackages.map(p => Map("pypi" -> Map("package" -> p))) @@ -98,15 +100,15 @@ object DatabricksUtilities { val GPUNotebooks: Seq[File] = ParallelizableNotebooks.filter(_.getAbsolutePath.contains("Fine-tune")) - def databricksGet(path: String): JsValue = { - val request = new HttpGet(BaseURL + path) + def databricksGet(path: String, apiVersion: String = "2.0"): JsValue = { + val request = new HttpGet(baseURL(apiVersion) + path) request.addHeader("Authorization", AuthValue) RESTHelpers.sendAndParseJson(request) } //TODO convert all this to typed code - def databricksPost(path: String, body: String): JsValue = { - val request = new HttpPost(BaseURL + path) + def databricksPost(path: String, body: String, apiVersion: String = "2.0"): JsValue = { + val request = new HttpPost(baseURL(apiVersion) + path) request.addHeader("Authorization", AuthValue) request.setEntity(new StringEntity(body)) RESTHelpers.sendAndParseJson(request) @@ -120,7 +122,7 @@ object DatabricksUtilities { } def getPoolIdByName(name: String): String = { - val jsonObj = databricksGet("instance-pools/list") + val jsonObj = databricksGet("instance-pools/list", apiVersion = "2.0") val cluster = jsonObj.select[Array[JsValue]]("instance_pools") .filter(_.select[String]("instance_pool_name") == name).head cluster.select[String]("instance_pool_id") @@ -230,7 +232,7 @@ object DatabricksUtilities { () } - def submitRun(clusterId: String, notebookPath: String): Int = { + def submitRun(clusterId: String, notebookPath: String): Long = { val body = s""" |{ @@ -244,7 +246,7 @@ object DatabricksUtilities { | "libraries": $Libraries |} """.stripMargin - databricksPost("jobs/runs/submit", body).select[Int]("run_id") + databricksPost("jobs/runs/submit", body).select[Long]("run_id") } def isClusterActive(clusterId: String): Boolean = { @@ -265,7 +267,7 @@ object DatabricksUtilities { libraryStatuses.forall(_.select[String]("status") == "INSTALLED") } - private def getRunStatuses(runId: Int): (String, Option[String]) = { + private def getRunStatuses(runId: Long): (String, Option[String]) = { val runObj = databricksGet(s"jobs/runs/get?run_id=$runId") val stateObj = runObj.select[JsObject]("state") val lifeCycleState = stateObj.select[String]("life_cycle_state") @@ -277,7 +279,7 @@ object DatabricksUtilities { } } - def getRunUrlAndNBName(runId: Int): (String, String) = { + def getRunUrlAndNBName(runId: Long): (String, String) = { val runObj = databricksGet(s"jobs/runs/get?run_id=$runId").asJsObject() val url = runObj.select[String]("run_page_url") .replaceAll("westus", Region) //TODO this seems like an ADB bug @@ -286,7 +288,7 @@ object DatabricksUtilities { } //scalastyle:off cyclomatic.complexity - def monitorJob(runId: Integer, + def monitorJob(runId: Long, timeout: Int, interval: Int = 8000, logLevel: Int = 1): Future[Unit] = { @@ -342,28 +344,28 @@ object DatabricksUtilities { workspaceMkDir(folderToCreate) val destination: String = folderToCreate + notebookFile.getName uploadNotebook(notebookFile, destination) - val runId: Int = submitRun(clusterId, destination) + val runId: Long = submitRun(clusterId, destination) val run: DatabricksNotebookRun = DatabricksNotebookRun(runId, notebookFile.getName) println(s"Successfully submitted job run id ${run.runId} for notebook ${run.notebookName}") run } - def cancelRun(runId: Int): Unit = { + def cancelRun(runId: Long): Unit = { println(s"Cancelling job $runId") databricksPost("jobs/runs/cancel", s"""{"run_id": $runId}""") () } - def listActiveJobs(clusterId: String): Vector[Int] = { + def listActiveJobs(clusterId: String): Vector[Long] = { //TODO this only gets the first 1k running jobs, full solution would page results databricksGet("jobs/runs/list?active_only=true&limit=1000") .asJsObject.fields.get("runs").map { runs => - runs.asInstanceOf[JsArray].elements.flatMap { - case run if clusterId == run.select[String]("cluster_instance.cluster_id") => - Some(run.select[Int]("run_id")) - case _ => None - } - }.getOrElse(Array().toVector: Vector[Int]) + runs.asInstanceOf[JsArray].elements.flatMap { + case run if clusterId == run.select[String]("cluster_instance.cluster_id") => + Some(run.select[Long]("run_id")) + case _ => None + } + }.getOrElse(Array().toVector: Vector[Long]) } def listInstalledLibraries(clusterId: String): Vector[JsValue] = { @@ -400,8 +402,8 @@ abstract class DatabricksTestHelper extends TestBase { def databricksTestHelper(clusterId: String, libraries: String, - notebooks: Seq[File]): mutable.ListBuffer[Int] = { - val jobIdsToCancel: mutable.ListBuffer[Int] = mutable.ListBuffer[Int]() + notebooks: Seq[File]): mutable.ListBuffer[Long] = { + val jobIdsToCancel: mutable.ListBuffer[Long] = mutable.ListBuffer[Long]() println("Checking if cluster is active") tryWithRetries(Seq.fill(60 * 15)(1000).toArray) { () => @@ -437,7 +439,7 @@ abstract class DatabricksTestHelper extends TestBase { jobIdsToCancel } - protected def afterAllHelper(jobIdsToCancel: mutable.ListBuffer[Int], + protected def afterAllHelper(jobIdsToCancel: mutable.ListBuffer[Long], clusterId: String, clusterName: String): Unit = { println("Suite test finished. Running afterAll procedure...") @@ -447,7 +449,7 @@ abstract class DatabricksTestHelper extends TestBase { } } -case class DatabricksNotebookRun(runId: Int, notebookName: String) { +case class DatabricksNotebookRun(runId: Long, notebookName: String) { def monitor(logLevel: Int = 2): Future[Any] = { monitorJob(runId, TimeoutInMillis, logLevel) }