Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle long run ids in adb tests #2102

Merged
merged 1 commit into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import scala.language.existentials
class DatabricksCPUTests extends DatabricksTestHelper {

val clusterId: String = createClusterInPool(ClusterName, AdbRuntime, NumWorkers, PoolId, "[]")
val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper(clusterId, Libraries, CPUNotebooks)
val jobIdsToCancel: ListBuffer[Long] = databricksTestHelper(clusterId, Libraries, CPUNotebooks)

protected override def afterAll(): Unit = {
afterAllHelper(jobIdsToCancel, clusterId, ClusterName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class DatabricksGPUTests extends DatabricksTestHelper {
"src", "main", "python", "horovod_installation.sh").getCanonicalFile
uploadFileToDBFS(horovodInstallationScript, "/FileStore/horovod-fix-commit/horovod_installation.sh")
val clusterId: String = createClusterInPool(GPUClusterName, AdbGpuRuntime, 2, GpuPoolId, GPUInitScripts)
val jobIdsToCancel: ListBuffer[Int] = databricksTestHelper(
val jobIdsToCancel: ListBuffer[Long] = databricksTestHelper(
clusterId, GPULibraries, GPUNotebooks)

protected override def afterAll(): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ object DatabricksUtilities {
lazy val Token: String = sys.env.getOrElse("MML_ADB_TOKEN", Secrets.AdbToken)
lazy val AuthValue: String = "Basic " + BaseEncoding.base64()
.encode(("token:" + Token).getBytes("UTF-8"))
val BaseURL = s"https://$Region.azuredatabricks.net/api/2.0/"

lazy val PoolId: String = getPoolIdByName(PoolName)
lazy val GpuPoolId: String = getPoolIdByName(GpuPoolName)
lazy val ClusterName = s"mmlspark-build-${LocalDateTime.now()}"
Expand Down Expand Up @@ -67,6 +67,8 @@ object DatabricksUtilities {
"interpret-community"
)

/** Builds the root URL of the Databricks REST API for a given API version.
  *
  * @param apiVersion version segment placed after `/api/` (for example "2.0")
  * @return the base URL, ending in a trailing slash, to which an endpoint path is appended
  */
def baseURL(apiVersion: String): String = {
  val host = s"$Region.azuredatabricks.net"
  s"https://$host/api/$apiVersion/"
}

val Libraries: String = (
List(Map("maven" -> Map("coordinates" -> PackageMavenCoordinate, "repo" -> PackageRepository))) ++
PipPackages.map(p => Map("pypi" -> Map("package" -> p)))
Expand Down Expand Up @@ -98,15 +100,15 @@ object DatabricksUtilities {

val GPUNotebooks: Seq[File] = ParallelizableNotebooks.filter(_.getAbsolutePath.contains("Fine-tune"))

def databricksGet(path: String): JsValue = {
val request = new HttpGet(BaseURL + path)
/** Sends a GET request to the Databricks REST API and returns the parsed response.
  *
  * The request URL is `baseURL(apiVersion) + path`, and the `Authorization` header is
  * set to `AuthValue` before the request is handed to `RESTHelpers.sendAndParseJson`.
  *
  * @param path       endpoint path appended to the versioned base URL
  * @param apiVersion API version used to build the base URL; defaults to "2.0"
  * @return the value produced by `RESTHelpers.sendAndParseJson` for this request
  */
def databricksGet(path: String, apiVersion: String = "2.0"): JsValue = {
  val endpoint = baseURL(apiVersion) + path
  val getRequest = new HttpGet(endpoint)
  getRequest.addHeader("Authorization", AuthValue)
  RESTHelpers.sendAndParseJson(getRequest)
}

//TODO convert all this to typed code
def databricksPost(path: String, body: String): JsValue = {
val request = new HttpPost(BaseURL + path)
def databricksPost(path: String, body: String, apiVersion: String = "2.0"): JsValue = {
val request = new HttpPost(baseURL(apiVersion) + path)
request.addHeader("Authorization", AuthValue)
request.setEntity(new StringEntity(body))
RESTHelpers.sendAndParseJson(request)
Expand All @@ -120,7 +122,7 @@ object DatabricksUtilities {
}

def getPoolIdByName(name: String): String = {
val jsonObj = databricksGet("instance-pools/list")
val jsonObj = databricksGet("instance-pools/list", apiVersion = "2.0")
val cluster = jsonObj.select[Array[JsValue]]("instance_pools")
.filter(_.select[String]("instance_pool_name") == name).head
cluster.select[String]("instance_pool_id")
Expand Down Expand Up @@ -230,7 +232,7 @@ object DatabricksUtilities {
()
}

def submitRun(clusterId: String, notebookPath: String): Int = {
def submitRun(clusterId: String, notebookPath: String): Long = {
val body =
s"""
|{
Expand All @@ -244,7 +246,7 @@ object DatabricksUtilities {
| "libraries": $Libraries
|}
""".stripMargin
databricksPost("jobs/runs/submit", body).select[Int]("run_id")
databricksPost("jobs/runs/submit", body).select[Long]("run_id")
}

def isClusterActive(clusterId: String): Boolean = {
Expand All @@ -265,7 +267,7 @@ object DatabricksUtilities {
libraryStatuses.forall(_.select[String]("status") == "INSTALLED")
}

private def getRunStatuses(runId: Int): (String, Option[String]) = {
private def getRunStatuses(runId: Long): (String, Option[String]) = {
val runObj = databricksGet(s"jobs/runs/get?run_id=$runId")
val stateObj = runObj.select[JsObject]("state")
val lifeCycleState = stateObj.select[String]("life_cycle_state")
Expand All @@ -277,7 +279,7 @@ object DatabricksUtilities {
}
}

def getRunUrlAndNBName(runId: Int): (String, String) = {
def getRunUrlAndNBName(runId: Long): (String, String) = {
val runObj = databricksGet(s"jobs/runs/get?run_id=$runId").asJsObject()
val url = runObj.select[String]("run_page_url")
.replaceAll("westus", Region) //TODO this seems like an ADB bug
Expand All @@ -286,7 +288,7 @@ object DatabricksUtilities {
}

//scalastyle:off cyclomatic.complexity
def monitorJob(runId: Integer,
def monitorJob(runId: Long,
timeout: Int,
interval: Int = 8000,
logLevel: Int = 1): Future[Unit] = {
Expand Down Expand Up @@ -342,28 +344,28 @@ object DatabricksUtilities {
workspaceMkDir(folderToCreate)
val destination: String = folderToCreate + notebookFile.getName
uploadNotebook(notebookFile, destination)
val runId: Int = submitRun(clusterId, destination)
val runId: Long = submitRun(clusterId, destination)
val run: DatabricksNotebookRun = DatabricksNotebookRun(runId, notebookFile.getName)
println(s"Successfully submitted job run id ${run.runId} for notebook ${run.notebookName}")
run
}

def cancelRun(runId: Int): Unit = {
/** Requests cancellation of a job run.
  *
  * Prints the run id to stdout, then posts `{"run_id": <runId>}` to `jobs/runs/cancel`.
  * The response is discarded.
  *
  * @param runId id of the run to cancel
  */
def cancelRun(runId: Long): Unit = {
  val payload = s"""{"run_id": $runId}"""
  println(s"Cancelling job $runId")
  databricksPost("jobs/runs/cancel", payload)
  ()
}

def listActiveJobs(clusterId: String): Vector[Int] = {
/** Lists the run ids of active job runs that belong to the given cluster.
  *
  * Queries `jobs/runs/list?active_only=true&limit=1000` and keeps the `run_id` of every
  * returned run whose `cluster_instance.cluster_id` equals `clusterId`. Yields an empty
  * vector when the response has no `runs` field.
  *
  * @param clusterId cluster id compared against each run's `cluster_instance.cluster_id`
  * @return run ids of the matching runs
  */
def listActiveJobs(clusterId: String): Vector[Long] = {
// TODO: this only gets the first 1k running jobs; a full solution would page through the results
databricksGet("jobs/runs/list?active_only=true&limit=1000")
.asJsObject.fields.get("runs").map { runs =>
runs.asInstanceOf[JsArray].elements.flatMap {
case run if clusterId == run.select[String]("cluster_instance.cluster_id") =>
Some(run.select[Int]("run_id"))
case _ => None
}
}.getOrElse(Array().toVector: Vector[Int])
runs.asInstanceOf[JsArray].elements.flatMap {
case run if clusterId == run.select[String]("cluster_instance.cluster_id") =>
Some(run.select[Long]("run_id"))
case _ => None
}
}.getOrElse(Array().toVector: Vector[Long])
}

def listInstalledLibraries(clusterId: String): Vector[JsValue] = {
Expand Down Expand Up @@ -400,8 +402,8 @@ abstract class DatabricksTestHelper extends TestBase {

def databricksTestHelper(clusterId: String,
libraries: String,
notebooks: Seq[File]): mutable.ListBuffer[Int] = {
val jobIdsToCancel: mutable.ListBuffer[Int] = mutable.ListBuffer[Int]()
notebooks: Seq[File]): mutable.ListBuffer[Long] = {
val jobIdsToCancel: mutable.ListBuffer[Long] = mutable.ListBuffer[Long]()

println("Checking if cluster is active")
tryWithRetries(Seq.fill(60 * 15)(1000).toArray) { () =>
Expand Down Expand Up @@ -437,7 +439,7 @@ abstract class DatabricksTestHelper extends TestBase {
jobIdsToCancel
}

protected def afterAllHelper(jobIdsToCancel: mutable.ListBuffer[Int],
protected def afterAllHelper(jobIdsToCancel: mutable.ListBuffer[Long],
clusterId: String,
clusterName: String): Unit = {
println("Suite test finished. Running afterAll procedure...")
Expand All @@ -447,7 +449,7 @@ abstract class DatabricksTestHelper extends TestBase {
}
}

case class DatabricksNotebookRun(runId: Int, notebookName: String) {
case class DatabricksNotebookRun(runId: Long, notebookName: String) {
/** Starts monitoring this run through `monitorJob`.
  *
  * `logLevel` is passed by name on purpose: the third positional parameter of
  * `monitorJob` is `interval`, so a positional argument would set `interval`
  * and leave `monitorJob`'s own `logLevel` default in effect.
  *
  * @param logLevel log level forwarded to `monitorJob`; defaults to 2
  * @return the future returned by `monitorJob`
  */
def monitor(logLevel: Int = 2): Future[Any] = {
  monitorJob(runId, TimeoutInMillis, logLevel = logLevel)
}
Expand Down
Loading