Skip to content

Commit

Permalink
[CELEBORN-1490][CIP-6] Enrich register shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Sep 4, 2024
1 parent 40a6546 commit e69b51b
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ChangePartitionManager(
// shuffleId -> set of partition id
private val inBatchPartitions =
JavaUtils.newConcurrentHashMap[Int, ConcurrentHashMap.KeySetView[Int, java.lang.Boolean]]()
// shuffleId -> isSegmentGranularityVisible flag supplied with the most recent change-partition
// request for that shuffle; read back when a batch of requests is handled and dropped in
// removeExpiredShuffle
private val shuffleIsSegmentGranularityVisible = JavaUtils.newConcurrentHashMap[Int, Boolean]()

private val batchHandleChangePartitionEnabled = conf.batchHandleChangePartitionEnabled
private val batchHandleChangePartitionExecutors = ThreadUtils.newDaemonCachedThreadPool(
Expand Down Expand Up @@ -103,7 +104,8 @@ class ChangePartitionManager(
if (distinctPartitions.nonEmpty) {
handleRequestPartitions(
shuffleId,
distinctPartitions)
distinctPartitions,
shuffleIsSegmentGranularityVisible.get(shuffleId))
}
}
}
Expand Down Expand Up @@ -153,7 +155,8 @@ class ChangePartitionManager(
partitionId: Int,
oldEpoch: Int,
oldPartition: PartitionLocation,
cause: Option[StatusCode] = None): Unit = {
cause: Option[StatusCode] = None,
isSegmentGranularityVisible: Boolean): Unit = {

val changePartition = ChangePartitionRequest(
context,
Expand All @@ -165,6 +168,7 @@ class ChangePartitionManager(
// check if there exists request for the partition, if do just register
val requests = changePartitionRequests.computeIfAbsent(shuffleId, rpcContextRegisterFunc)
inBatchPartitions.computeIfAbsent(shuffleId, inBatchShuffleIdRegisterFunc)
shuffleIsSegmentGranularityVisible.put(shuffleId, isSegmentGranularityVisible)

lifecycleManager.commitManager.registerCommitPartitionRequest(
shuffleId,
Expand Down Expand Up @@ -195,7 +199,7 @@ class ChangePartitionManager(
}
}
if (!batchHandleChangePartitionEnabled) {
handleRequestPartitions(shuffleId, Array(changePartition))
handleRequestPartitions(shuffleId, Array(changePartition), isSegmentGranularityVisible)
}
}

Expand All @@ -215,7 +219,8 @@ class ChangePartitionManager(

def handleRequestPartitions(
shuffleId: Int,
changePartitions: Array[ChangePartitionRequest]): Unit = {
changePartitions: Array[ChangePartitionRequest],
isSegmentGranularityVisible: Boolean): Unit = {
val requestsMap = changePartitionRequests.get(shuffleId)

val changes = changePartitions.map { change =>
Expand Down Expand Up @@ -296,7 +301,8 @@ class ChangePartitionManager(
if (!lifecycleManager.reserveSlotsWithRetry(
shuffleId,
new util.HashSet(candidates.toSet.asJava),
newlyAllocatedLocations)) {
newlyAllocatedLocations,
isSegmentGranularityVisible = isSegmentGranularityVisible)) {
logError(s"[Update partition] failed for $shuffleId.")
replyFailure(StatusCode.RESERVE_SLOTS_FAILED)
return
Expand Down Expand Up @@ -324,6 +330,8 @@ class ChangePartitionManager(
s"shuffle $shuffleId, succeed partitions: " +
s"$changes.")
}

// TODO: record the new partition locations and notify downstream tasks of them, for scenarios
// where a downstream task starts before the upstream task has finished.
locations
}
replySuccess(newPrimaryLocations.toArray)
Expand All @@ -346,6 +354,7 @@ class ChangePartitionManager(
/**
 * Drops the per-shuffle state this manager keeps for `shuffleId`: its pending change-partition
 * requests, its in-batch partition set, its segment-granularity-visible flag and its `locks`
 * entry.
 *
 * @param shuffleId id of the shuffle whose state is removed
 */
def removeExpiredShuffle(shuffleId: Int): Unit = {
changePartitionRequests.remove(shuffleId)
inBatchPartitions.remove(shuffleId)
shuffleIsSegmentGranularityVisible.remove(shuffleId)
locks.remove(shuffleId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
batchHandleCommitPartitionSchedulerThread.foreach(ThreadUtils.shutdown(_))
}

def registerShuffle(shuffleId: Int, numMappers: Int): Unit = {
def registerShuffle(
shuffleId: Int,
numMappers: Int,
isSegmentGranularityVisible: Boolean): Unit = {
committedPartitionInfo.put(
shuffleId,
ShuffleCommittedInfo(
Expand All @@ -191,7 +194,14 @@ class CommitManager(appUniqueId: String, val conf: CelebornConf, lifecycleManage
new AtomicInteger(),
JavaUtils.newConcurrentHashMap[Int, AtomicInteger]()))

getCommitHandler(shuffleId).registerShuffle(shuffleId, numMappers)
getCommitHandler(shuffleId).registerShuffle(
shuffleId,
numMappers,
isSegmentGranularityVisible);
}

/**
 * Reports whether the given shuffle is segment-granularity visible, as answered by the commit
 * handler that `getCommitHandler` returns for it.
 *
 * @param shuffleId id of the shuffle to query
 * @return the value returned by that handler's `isSegmentGranularityVisible`
 */
def isSegmentGranularityVisible(shuffleId: Int): Boolean =
  getCommitHandler(shuffleId).isSegmentGranularityVisible(shuffleId)

def isMapperEnded(shuffleId: Int, mapId: Int): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,17 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
val mapId = pb.getMapId
val attemptId = pb.getAttemptId
val partitionId = pb.getPartitionId
val isSegmentGranularityVisible = pb.getIsSegmentGranularityVisible
logDebug(s"Received Register map partition task request, " +
s"$shuffleId, $numMappers, $mapId, $attemptId, $partitionId.")
s"$shuffleId, $numMappers, $mapId, $attemptId, $partitionId, $isSegmentGranularityVisible.")
shufflePartitionType.putIfAbsent(shuffleId, PartitionType.MAP)
offerAndReserveSlots(
RegisterCallContext(context, partitionId),
shuffleId,
numMappers,
numMappers,
partitionId)
partitionId,
isSegmentGranularityVisible)

case pb: PbRevive =>
val shuffleId = pb.getShuffleId
Expand Down Expand Up @@ -377,7 +379,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shuffleId,
partitionId,
epoch,
oldPartition)
oldPartition,
isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId))

case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId) =>
logTrace(s"Received MapperEnd TaskEnd request, " +
Expand Down Expand Up @@ -496,7 +499,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shuffleId: Int,
numMappers: Int,
numPartitions: Int,
partitionId: Int = -1): Unit = {
partitionId: Int = -1,
isSegmentGranularityVisible: Boolean = false): Unit = {
val partitionType = getPartitionType(shuffleId)
registeringShuffleRequest.synchronized {
if (registeringShuffleRequest.containsKey(shuffleId)) {
Expand Down Expand Up @@ -575,7 +579,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shuffleId,
partitionId,
-1,
null)
null,
isSegmentGranularityVisible = commitManager.isSegmentGranularityVisible(shuffleId))
}
}

Expand Down Expand Up @@ -681,7 +686,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shuffleId,
candidatesWorkers,
slots,
updateEpoch = false)
updateEpoch = false,
isSegmentGranularityVisible)

// If reserve slots failed, clear allocated resources, reply ReserveSlotFailed and return.
if (!reserveSlotsSuccess) {
Expand All @@ -703,11 +709,18 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
registeredShuffle.add(shuffleId)
commitManager.registerShuffle(shuffleId, numMappers)
commitManager.registerShuffle(
shuffleId,
numMappers,
isSegmentGranularityVisible)

// Fifth, reply the allocated partition location to ShuffleClient.
logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
val allPrimaryPartitionLocations = slots.asScala.flatMap(_._2._1.asScala).toArray
commitManager.registerShuffle(
shuffleId,
numMappers,
isSegmentGranularityVisible)
replyRegisterShuffle(RegisterShuffleResponse(
StatusCode.SUCCESS,
allPrimaryPartitionLocations))
Expand Down Expand Up @@ -761,7 +774,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
partitionIds.get(idx),
oldEpochs.get(idx),
oldPartitions.get(idx),
Some(causes.get(idx)))
Some(causes.get(idx)),
commitManager.isSegmentGranularityVisible(shuffleId))
}
}

Expand Down Expand Up @@ -1083,7 +1097,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
*/
private def reserveSlots(
shuffleId: Int,
slots: WorkerResource): util.List[WorkerInfo] = {
slots: WorkerResource,
isSegmentGranularityVisible: Boolean = false): util.List[WorkerInfo] = {
val reserveSlotFailedWorkers = new ShuffleFailedWorkers()
val failureInfos = new util.concurrent.CopyOnWriteArrayList[String]()
val workerPartitionLocations = slots.asScala.filter(p => !p._2._1.isEmpty || !p._2._2.isEmpty)
Expand All @@ -1106,7 +1121,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
conf.pushDataTimeoutMs,
if (getPartitionType(shuffleId) == PartitionType.MAP)
conf.clientShuffleMapPartitionSplitEnabled
else true))
else true,
isSegmentGranularityVisible))
futures.add((future, workerInfo))
}(ec)
}
Expand Down Expand Up @@ -1303,7 +1319,8 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
shuffleId: Int,
candidates: util.HashSet[WorkerInfo],
slots: WorkerResource,
updateEpoch: Boolean = true): Boolean = {
updateEpoch: Boolean = true,
isSegmentGranularityVisible: Boolean = false): Boolean = {
var requestSlots = slots
val reserveSlotsMaxRetries = conf.clientReserveSlotsMaxRetries
val reserveSlotsRetryWait = conf.clientReserveSlotsRetryWait
Expand All @@ -1316,7 +1333,7 @@ class LifecycleManager(val appUniqueId: String, val conf: CelebornConf) extends
}
// reserve buffers
logInfo(s"Try reserve slots for $shuffleId for $retryTimes times.")
val reserveFailedWorkers = reserveSlots(shuffleId, requestSlots)
val reserveFailedWorkers = reserveSlots(shuffleId, requestSlots, isSegmentGranularityVisible)
if (reserveFailedWorkers.isEmpty) {
success = true
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,20 @@ abstract class CommitHandler(
partitionId: Int,
recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean)

def registerShuffle(shuffleId: Int, numMappers: Int): Unit = {
/**
 * Registers a shuffle with this handler by storing a new, empty concurrent map for it in
 * `reducerFileGroupsMap` (any entry already stored under `shuffleId` is replaced).
 *
 * @param shuffleId id of the shuffle being registered
 * @param numMappers not read by this base implementation
 * @param isSegmentGranularityVisible not read by this base implementation yet; see the TODO below
 */
def registerShuffle(
shuffleId: Int,
numMappers: Int,
isSegmentGranularityVisible: Boolean): Unit = {
// TODO: if isSegmentGranularityVisible is set to true, it is necessary to handle the pending
// get-partition requests of downstream reduce tasks here, in scenarios that allow a
// downstream task to start before the upstream task, e.g. flink hybrid shuffle.
reducerFileGroupsMap.put(shuffleId, JavaUtils.newConcurrentHashMap())
}

/**
 * Base implementation: always answers `false`, regardless of `shuffleId`.
 *
 * @param shuffleId id of the shuffle to query (ignored here)
 * @return `false`
 */
def isSegmentGranularityVisible(shuffleId: Int): Boolean = false

def doParallelCommitFiles(
shuffleId: Int,
shuffleCommittedInfo: ShuffleCommittedInfo,
Expand Down Expand Up @@ -463,7 +473,8 @@ abstract class CommitHandler(
primaryPartitionUniqueIds: util.Iterator[String],
replicaPartitionUniqueIds: util.Iterator[String],
primaryPartMap: ConcurrentHashMap[String, PartitionLocation],
replicaPartMap: ConcurrentHashMap[String, PartitionLocation]): Unit = {
replicaPartMap: ConcurrentHashMap[String, PartitionLocation],
isSegmentGranularityVisible: Boolean = false): Unit = {
val committedPartitions = new util.HashMap[String, PartitionLocation]
primaryPartitionUniqueIds.asScala.foreach { id =>
val partitionLocation = primaryPartMap.get(id)
Expand All @@ -488,6 +499,8 @@ abstract class CommitHandler(
}
}

// TODO: to support an upstream task writing while a downstream task reads simultaneously,
// the partition location information should be recorded when the upstream task starts,
// rather than when it ends.
committedPartitions.values().asScala.foreach { partition =>
val partitionLocations = reducerFileGroupsMap.get(shuffleId).computeIfAbsent(
partition.getId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ class MapPartitionCommitHandler(
// shuffleId -> in processing partitionId set
private val inProcessMapPartitionEndIds = JavaUtils.newConcurrentHashMap[Int, util.Set[Integer]]()

// shuffleId -> boolean. Records whether the shuffle is visible at the segment level,
// facilitating future optimization of worker read and write processes.
private val shuffleIsSegmentGranularityVisible = JavaUtils.newConcurrentHashMap[Int, Boolean]()

override def getPartitionType(): PartitionType = {
PartitionType.MAP
}
Expand Down Expand Up @@ -113,6 +116,7 @@ class MapPartitionCommitHandler(
/**
 * Removes this handler's per-shuffle entries for `shuffleId` (in-process map-partition-end ids,
 * succeeded partition ids and the segment-granularity-visible flag), then lets the parent
 * implementation remove whatever it tracks.
 *
 * @param shuffleId id of the shuffle whose state is removed
 */
override def removeExpiredShuffle(shuffleId: Int): Unit = {
inProcessMapPartitionEndIds.remove(shuffleId)
shuffleSucceedPartitionIds.remove(shuffleId)
shuffleIsSegmentGranularityVisible.remove(shuffleId)
super.removeExpiredShuffle(shuffleId)
}

Expand Down Expand Up @@ -143,7 +147,8 @@ class MapPartitionCommitHandler(
getPartitionUniqueIds(shuffleCommittedInfo.committedPrimaryIds, partitionId),
getPartitionUniqueIds(shuffleCommittedInfo.committedReplicaIds, partitionId),
parallelCommitResult.primaryPartitionLocationMap,
parallelCommitResult.replicaPartitionLocationMap)
parallelCommitResult.replicaPartitionLocationMap,
shuffleIsSegmentGranularityVisible.get(shuffleId))
}

(dataLost, parallelCommitResult.commitFilesFailedWorkers)
Expand Down Expand Up @@ -211,7 +216,23 @@ class MapPartitionCommitHandler(
(dataCommitSuccess, false)
}

/**
 * Registers the shuffle through the parent implementation and then records its
 * segment-granularity-visible flag in `shuffleIsSegmentGranularityVisible`.
 *
 * @param shuffleId id of the shuffle being registered
 * @param numMappers forwarded unchanged to the parent implementation
 * @param isSegmentGranularityVisible flag stored for later `isSegmentGranularityVisible` lookups
 */
override def registerShuffle(
shuffleId: Int,
numMappers: Int,
isSegmentGranularityVisible: Boolean): Unit = {
super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible)
// stored after the parent registration has completed
shuffleIsSegmentGranularityVisible.put(shuffleId, isSegmentGranularityVisible)
}

/**
 * Looks up the segment-granularity-visible flag recorded for the shuffle.
 *
 * @param shuffleId id of the shuffle to query
 * @return the recorded flag, or `false` when nothing is recorded for `shuffleId`
 */
override def isSegmentGranularityVisible(shuffleId: Int): Boolean =
  // The explicit default spells out what a missing key yields here.
  shuffleIsSegmentGranularityVisible.getOrDefault(shuffleId, false)

override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: Int): Unit = {
// TODO: to support a downstream map task starting before the upstream reduce task, this
// should wait for the upstream task to register the shuffle and only then reply to these
// GetReducerFileGroup requests.
// Note that flink hybrid shuffle should support this in the future.

// we need obtain the last succeed partitionIds
val lastSucceedPartitionIds =
shuffleSucceedPartitionIds.getOrDefault(shuffleId, new util.HashSet[Integer]())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,11 @@ class ReducePartitionCommitHandler(
}
}

override def registerShuffle(shuffleId: Int, numMappers: Int): Unit = {
super.registerShuffle(shuffleId, numMappers)
/**
 * Registers the shuffle through the parent implementation, then installs an empty set of
 * `RpcCallContext`s for it in `getReducerFileGroupRequest` and initialises its mapper attempts.
 *
 * @param shuffleId id of the shuffle being registered
 * @param numMappers number of mappers, forwarded to the parent and to `initMapperAttempts`
 * @param isSegmentGranularityVisible forwarded unchanged to the parent implementation
 */
override def registerShuffle(
    shuffleId: Int,
    numMappers: Int,
    isSegmentGranularityVisible: Boolean): Unit = {
  super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible)
  val emptyRequestSet = new util.HashSet[RpcCallContext]()
  getReducerFileGroupRequest.put(shuffleId, emptyRequestSet)
  initMapperAttempts(shuffleId, numMappers)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
res.workerResource,
updateEpoch = false)

lifecycleManager.commitManager.registerShuffle(shuffleId, 1)
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
0 until 10 foreach { partitionId =>
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
}
Expand Down Expand Up @@ -116,7 +116,7 @@ class LifecycleManagerCommitFilesSuite extends WithShuffleClientSuite with MiniC
res.workerResource,
updateEpoch = false)

lifecycleManager.commitManager.registerShuffle(shuffleId, 1)
lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
0 until 10 foreach { partitionId =>
lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, partitionId)
}
Expand Down

0 comments on commit e69b51b

Please sign in to comment.