Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CELEBORN-1644] Optimize handle merged data on stage end #2806

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1391,10 +1391,15 @@ public void onSuccess(ByteBuffer response) {
groupedBatchId,
Arrays.toString(batchIds));
pushState.removeBatch(groupedBatchId, hostPort);
if (response.remaining() > 0 && response.get() == StatusCode.MAP_ENDED.getValue()) {
mapperEndMap
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
.add(mapId);
if (response.remaining() > 0) {
int retCode = response.get();
if (retCode == StatusCode.MAP_ENDED.getValue()) {
Copy link
Member

@SteNicholas SteNicholas Oct 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any compatibility problem between an old client version and a new server version — for example, the situation where the server returns STAGE_ENDED but the client only handles MAP_ENDED?

Copy link
Contributor Author

@FMX FMX Oct 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

STAGE_ENDED will be returned if the field stageEnd is true. So old servers won't return STAGE_ENDED.

mapperEndMap
.computeIfAbsent(shuffleId, (id) -> ConcurrentHashMap.newKeySet())
.add(mapId);
} else if (retCode == StatusCode.STAGE_ENDED.getValue()) {
stageEndShuffleSet.add(shuffleId);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ abstract class CommitHandler(
shuffleId: Int,
shuffleCommittedInfo: ShuffleCommittedInfo,
params: ArrayBuffer[CommitFilesParam],
commitFilesFailedWorkers: ShuffleFailedWorkers): Unit = {
commitFilesFailedWorkers: ShuffleFailedWorkers,
isStageEnd: Boolean = false): Unit = {

def retryCommitFiles(status: CommitFutureWithStatus, currentTime: Long): Unit = {
status.retriedTimes = status.retriedTimes + 1
Expand Down Expand Up @@ -290,7 +291,8 @@ abstract class CommitHandler(
param.replicaIds,
getMapperAttempts(shuffleId),
commitEpoch.incrementAndGet(),
mockCommitFilesFailure)
mockCommitFilesFailure,
isStageEnd)
val future = commitFiles(param.worker, msg)

futures.add(CommitFutureWithStatus(future, msg, param.worker, 1, startTime))
Expand All @@ -311,13 +313,14 @@ abstract class CommitHandler(
while (iter.hasNext) {
val status = iter.next()
val worker = status.workerInfo
val workerAddr = worker.readableAddress()
if (status.future.isCompleted) {
status.future.value.get match {
case scala.util.Success(res) =>
res.status match {
case StatusCode.SUCCESS | StatusCode.PARTIAL_SUCCESS | StatusCode.SHUFFLE_NOT_REGISTERED | StatusCode.REQUEST_FAILED | StatusCode.WORKER_EXCLUDED =>
logInfo(s"Request commitFiles return ${res.status} for " +
s"${Utils.makeShuffleKey(appUniqueId, shuffleId)}")
s"${Utils.makeShuffleKey(appUniqueId, shuffleId)} from ${workerAddr}")
if (res.status != StatusCode.SUCCESS && res.status != StatusCode.WORKER_EXCLUDED) {
commitFilesFailedWorkers.put(worker, (res.status, System.currentTimeMillis()))
}
Expand All @@ -326,12 +329,12 @@ abstract class CommitHandler(
case StatusCode.COMMIT_FILES_MOCK_FAILURE =>
if (status.retriedTimes < maxRetries) {
logError(s"Request commitFiles return ${res.status} for " +
s"${Utils.makeShuffleKey(appUniqueId, shuffleId)} for ${status.retriedTimes}/$maxRetries, will retry")
s"${Utils.makeShuffleKey(appUniqueId, shuffleId)} from ${workerAddr} for ${status.retriedTimes}/$maxRetries, will retry")
retryCommitFiles(status, currentTime)
} else {
logError(
s"Request commitFiles return ${StatusCode.COMMIT_FILES_MOCK_FAILURE} for " +
s"${Utils.makeShuffleKey(appUniqueId, shuffleId)} for ${status.retriedTimes}/$maxRetries, will not retry")
s"${Utils.makeShuffleKey(appUniqueId, shuffleId)} from ${workerAddr} for ${status.retriedTimes}/$maxRetries, will not retry")
val res = createFailResponse(status)
processResponse(res, status.workerInfo)
iter.remove()
Expand Down Expand Up @@ -439,7 +442,7 @@ abstract class CommitHandler(
replicaIds)
}

doParallelCommitFiles(shuffleId, shuffleCommittedInfo, params, commitFilesFailedWorkers)
doParallelCommitFiles(shuffleId, shuffleCommittedInfo, params, commitFilesFailedWorkers, true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only the ReducePartitionCommitHandler can set stageEnd to true when calling the handleFinalCommitFiles method.


logInfo(s"Shuffle $shuffleId " +
s"commit files complete. File count ${shuffleCommittedInfo.currentShuffleFileCount.sum()} " +
Expand Down
1 change: 1 addition & 0 deletions common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ message PbCommitFiles {
repeated int32 mapAttempts = 5;
int64 epoch = 6;
bool mockFailure = 7;
bool stageEnd = 8;
}

message PbCommitFilesResponse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,8 @@ object ControlMessages extends Logging {
replicaIds: util.List[String],
mapAttempts: Array[Int],
epoch: Long,
var mockFailure: Boolean = false)
var mockFailure: Boolean = false,
val isStageEnd: Boolean)
extends WorkerMessage

case class CommitFilesResponse(
Expand Down Expand Up @@ -893,7 +894,8 @@ object ControlMessages extends Logging {
replicaIds,
mapAttempts,
epoch,
mockFailure) =>
mockFailure,
isStageEnd) =>
val payload = PbCommitFiles.newBuilder()
.setApplicationId(applicationId)
.setShuffleId(shuffleId)
Expand All @@ -902,6 +904,7 @@ object ControlMessages extends Logging {
.addAllMapAttempts(mapAttempts.map(Integer.valueOf).toIterable.asJava)
.setEpoch(epoch)
.setMockFailure(mockFailure)
.setStageEnd(isStageEnd)
.build().toByteArray
new TransportMessage(MessageType.COMMIT_FILES, payload)

Expand Down Expand Up @@ -1269,7 +1272,8 @@ object ControlMessages extends Logging {
pbCommitFiles.getReplicaIdsList,
pbCommitFiles.getMapAttemptsList.asScala.map(_.toInt).toArray,
pbCommitFiles.getEpoch,
pbCommitFiles.getMockFailure)
pbCommitFiles.getMockFailure,
pbCommitFiles.getStageEnd)

case COMMIT_FILES_RESPONSE_VALUE =>
val pbCommitFilesResponse = PbCommitFilesResponse.parseFrom(message.getPayload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ private[deploy] class Controller(

var storageManager: StorageManager = _
var shuffleMapperAttempts: ConcurrentHashMap[String, AtomicIntegerArray] = _
var presumptiveEndedShuffles: ConcurrentHashMap.KeySetView[String, java.lang.Boolean] = _
// shuffleKey -> (epoch -> CommitInfo)
var shuffleCommitInfos: ConcurrentHashMap[String, ConcurrentHashMap[Long, CommitInfo]] = _
var shufflePartitionType: ConcurrentHashMap[String, PartitionType] = _
Expand All @@ -68,6 +69,7 @@ private[deploy] class Controller(
shufflePartitionType = worker.shufflePartitionType
shufflePushDataTimeout = worker.shufflePushDataTimeout
shuffleMapperAttempts = worker.shuffleMapperAttempts
presumptiveEndedShuffles = worker.presumptiveEndedShuffles
shuffleCommitInfos = worker.shuffleCommitInfos
workerInfo = worker.workerInfo
partitionLocationInfo = worker.partitionLocationInfo
Expand Down Expand Up @@ -122,7 +124,8 @@ private[deploy] class Controller(
replicaIds,
mapAttempts,
epoch,
mockFailure) =>
mockFailure,
isStageEnd) =>
checkAuth(context, applicationId)
val shuffleKey = Utils.makeShuffleKey(applicationId, shuffleId)
logDebug(s"Received CommitFiles request, $shuffleKey, primary files" +
Expand All @@ -135,7 +138,8 @@ private[deploy] class Controller(
replicaIds,
mapAttempts,
epoch,
mockFailure)
mockFailure,
isStageEnd)
})
logDebug(s"Done processed CommitFiles request with shuffleKey $shuffleKey, in " +
s"$commitFilesTimeMs ms.")
Expand Down Expand Up @@ -373,7 +377,8 @@ private[deploy] class Controller(
replicaIds: jList[String],
mapAttempts: Array[Int],
epoch: Long,
mockFailure: Boolean): Unit = {
mockFailure: Boolean,
isStageEnd: Boolean): Unit = {
if (mockFailure) {
logError(s"Mock commit files failure for Shuffle $shuffleKey!")
context.reply(
Expand Down Expand Up @@ -411,6 +416,9 @@ private[deploy] class Controller(

val shuffleCommitTimeout = conf.workerShuffleCommitTimeout

if (isStageEnd) {
presumptiveEndedShuffles.add(shuffleKey)
}
shuffleCommitInfos.putIfAbsent(shuffleKey, JavaUtils.newConcurrentHashMap[Long, CommitInfo]())
val epochCommitMap = shuffleCommitInfos.get(shuffleKey)
epochCommitMap.putIfAbsent(epoch, new CommitInfo(null, CommitInfo.COMMIT_NOTSTARTED))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler

private var partitionLocationInfo: WorkerPartitionLocationInfo = _
private var shuffleMapperAttempts: ConcurrentHashMap[String, AtomicIntegerArray] = _
private var presumptiveEndedShuffles: ConcurrentHashMap.KeySetView[String, java.lang.Boolean] = _
private var shufflePartitionType: ConcurrentHashMap[String, PartitionType] = _
private var shufflePushDataTimeout: ConcurrentHashMap[String, Long] = _
private var replicateThreadPool: ThreadPoolExecutor = _
Expand Down Expand Up @@ -79,6 +80,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
replicateThreadPool = worker.replicateThreadPool
unavailablePeers = worker.unavailablePeers
replicateClientFactory = worker.replicateClientFactory
presumptiveEndedShuffles = worker.presumptiveEndedShuffles
registered = Some(worker.registered)
workerInfo = worker.workerInfo
diskReserveSize = worker.conf.workerDiskReserveSize
Expand Down Expand Up @@ -475,7 +477,13 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
// A shuffle can trigger multiple CommitFiles requests, for reasons like: HARD_SPLIT happens, StageEnd.
// If MapperAttempts but the value is -1 for the mapId(-1 means the map has not yet finished),
// it's probably because commitFiles for HARD_SPLIT happens.
if (shuffleMapperAttempts.containsKey(shuffleKey)) {
if (presumptiveEndedShuffles.contains(shuffleKey)) {
logDebug(s"Receive push merged data from speculative " +
s"task(shuffle $shuffleKey, map $mapId, attempt $attemptId), " +
s"but this stage is ended.")
callbackWithTimer.onSuccess(
ByteBuffer.wrap(Array[Byte](StatusCode.STAGE_ENDED.getValue)))
} else if (shuffleMapperAttempts.containsKey(shuffleKey)) {
if (-1 != shuffleMapperAttempts.get(shuffleKey).get(mapId)) {
logDebug(s"Receive push merged data from speculative " +
s"task(shuffle $shuffleKey, map $mapId, attempt $attemptId), " +
Expand All @@ -490,6 +498,7 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
ByteBuffer.wrap(Array[Byte](StatusCode.HARD_SPLIT.getValue)))
}
} else {
// This means that this stage is ended and invoked commit files by stage end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line should be removed.

if (storageManager.shuffleKeySet().contains(shuffleKey)) {
// If there is no shuffle key in shuffleMapperAttempts but there is shuffle key
// in StorageManager. This partition should be HARD_SPLIT partition and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ private[celeborn] class Worker(
val registered = new AtomicBoolean(false)
val shuffleMapperAttempts: ConcurrentHashMap[String, AtomicIntegerArray] =
JavaUtils.newConcurrentHashMap[String, AtomicIntegerArray]()
val presumptiveEndedShuffles: ConcurrentHashMap.KeySetView[String, java.lang.Boolean] =
ConcurrentHashMap.newKeySet[String]()
val shufflePartitionType: ConcurrentHashMap[String, PartitionType] =
JavaUtils.newConcurrentHashMap[String, PartitionType]
var shufflePushDataTimeout: ConcurrentHashMap[String, Long] =
Expand Down Expand Up @@ -742,6 +744,7 @@ private[celeborn] class Worker(
shufflePartitionType.remove(shuffleKey)
shufflePushDataTimeout.remove(shuffleKey)
shuffleMapperAttempts.remove(shuffleKey)
presumptiveEndedShuffles.remove(shuffleKey)
shuffleCommitInfos.remove(shuffleKey)
workerInfo.releaseSlots(shuffleKey)
val applicationId = Utils.splitShuffleKey(shuffleKey)._1
Expand Down
Loading