Commit

add some test & format
zaynt4606 committed Oct 10, 2024
1 parent 64f4771 · commit 09b4226
Showing 2 changed files with 17 additions and 6 deletions.
@@ -39,7 +39,6 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite
     .set(CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key, "3")
     .set(CelebornConf.CLIENT_SLOT_ASSIGN_MAX_WORKERS.key, "1")
     .set(CelebornConf.MASTER_SLOT_ASSIGN_EXTRA_SLOTS.key, "0")
-    .set(CelebornConf.CLIENT_BATCH_HANDLE_CHANGE_PARTITION_ENABLED.key, "false")
     .set(CelebornConf.TEST_CLIENT_UPDATE_AVAILABLE_WORKER.key, "true")
 
   override def beforeAll(): Unit = {
@@ -95,11 +94,7 @@ class ChangePartitionManagerUpdateWorkersSuite extends WithShuffleClientSuite
       lifecycleManager.shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
     }
     assert(lifecycleManager.workerSnapshots(shuffleId).size() == 1)
-    lifecycleManager.workerSnapshots(shuffleId).forEach {
-      case (workerInfo, partitionLocationInfo) =>
-        logInfo(s"worker: ${workerInfo}; partitionLocationInfo size: ${partitionLocationInfo.getPrimaryPartitions().size()}")
-    }
-    ids.forEach { partitionId =>
+    ids.forEach { partitionId: Integer =>
       val req = ChangePartitionRequest(
         null,
         shuffleId,
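Aside from dropping the debug logging, the one substantive edit in the hunk above is that the forEach lambda parameter is now explicitly typed as the boxed Integer. Assuming ids is a java.util.List[Integer] (its declaration sits outside the hunk), forEach takes a java.util.function.Consumer, which Scala 2.12+ satisfies by SAM-converting the closure; the annotation pins the parameter type rather than leaving it to inference. A minimal standalone sketch of the same pattern (names and values are illustrative, not from the suite):

import java.{lang => jl, util => ju}

object ForEachSamDemo {
  def main(args: Array[String]): Unit = {
    // Stand-in for the suite's partition-id list; the element type is
    // assumed to be java.lang.Integer, matching the typed parameter above.
    val ids = new ju.ArrayList[jl.Integer]()
    (0 until 3).foreach(i => ids.add(i))

    // java.util.List#forEach expects a java.util.function.Consumer.
    // Scala 2.12+ SAM-converts this closure; the explicit annotation
    // makes the boxed parameter type visible at the call site.
    ids.forEach { partitionId: jl.Integer =>
      println(s"partitionId = $partitionId")
    }
  }
}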
@@ -56,4 +56,20 @@ class RetryReviveTest extends AnyFunSuite
     assert(result.size == 1000)
     ss.stop()
   }
+
+  test("celeborn spark integration test - retry revive with available workers from heartbeat") {
+    val sparkConf = new SparkConf()
+      .set(s"spark.${CelebornConf.TEST_CLIENT_RETRY_REVIVE.key}", "true")
+      .set(s"spark.${CelebornConf.CLIENT_PUSH_MAX_REVIVE_TIMES.key}", "3")
+      .set(s"spark.${CelebornConf.CLIENT_SLOT_ASSIGN_MAX_WORKERS.key}", "1")
+      .set(s"spark.${CelebornConf.MASTER_SLOT_ASSIGN_EXTRA_SLOTS.key}", "0")
+      .setAppName("celeborn-demo").setMaster("local[2]")
+    val ss = SparkSession.builder()
+      .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
+      .getOrCreate()
+    val result = ss.sparkContext.parallelize(1 to 1000, 2)
+      .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(4).collect()
+    assert(result.size == 1000)
+    ss.stop()
+  }
 }
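The new test reuses the harness's updateSparkConf helper to route Spark's shuffle through Celeborn, so that wiring stays hidden. For orientation, a rough standalone equivalent is sketched below; the shuffle-manager class and the master-endpoints property come from Celeborn's standard Spark setup, while the RetryReviveDemo name, the localhost:9097 endpoint, and the literal key spellings are assumptions (the test itself resolves keys through CelebornConf constants):

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object RetryReviveDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("celeborn-demo")
      // Route shuffle I/O through Celeborn's Spark shuffle manager.
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
      // Endpoint of a locally running Celeborn master (assumed).
      .set("spark.celeborn.master.endpoints", "localhost:9097")
      // Mirror the test: force revive on push and cap retries at 3.
      // (Key spellings assumed; the suite uses CelebornConf constants.)
      .set("spark.celeborn.test.client.retryRevive", "true")
      .set("spark.celeborn.client.push.revive.maxRetries", "3")

    val ss = SparkSession.builder().config(conf).getOrCreate()
    // Same workload as the test: 1000 keys, each with a ~3.9 KB string
    // value, grouped into 4 reduce partitions to exercise shuffle pushes.
    val result = ss.sparkContext.parallelize(1 to 1000, 2)
      .map(i => (i, Range(1, 1000).mkString(",")))
      .groupByKey(4)
      .collect()
    assert(result.length == 1000)
    ss.stop()
  }
}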
