[CELEBORN-1496] Differentiate map results with only different stageAttemptId

### What changes were proposed in this pull request?
Encode the map attempt number as `attemptNumber = (stageAttemptId << 16) | taskAttemptNumber` so that map results that differ only in their stageAttemptId can be differentiated.
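
For reference, here is a minimal sketch of the packing scheme (the `AttemptIdEncoding` class name and the decode helpers are illustrative only and not part of this PR):

```java
public final class AttemptIdEncoding {

  // High 16 bits carry the stage attempt number, low 16 bits the task attempt number.
  static int encode(int stageAttemptNumber, int taskAttemptNumber) {
    return (stageAttemptNumber << 16) | taskAttemptNumber;
  }

  static int decodeStageAttempt(int encoded) {
    return encoded >>> 16;
  }

  static int decodeTaskAttempt(int encoded) {
    return encoded & 0xFFFF;
  }

  public static void main(String[] args) {
    // Two attempts of the same map task that previously collided on attemptNumber = 1:
    System.out.println(encode(0, 1)); // 1     (stage attempt 0, task attempt 1)
    System.out.println(encode(1, 1)); // 65537 (stage attempt 1, task attempt 1)
  }
}
```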

### Why are the changes needed?
If map tasks that differ only in their stageAttemptId cannot be distinguished, shuffle read may mix the shuffle write batches of two different map task attempts, causing a data correctness issue. For example, a map task in stage attempt 0 and its re-run in stage attempt 1 can both report task attemptNumber = 1, so their pushed batches could not previously be told apart during shuffle read.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added a unit test: org.apache.spark.shuffle.celeborn.SparkShuffleManagerSuite#testWrongSparkConf_MaxAttemptLimit

Closes apache#2609 from jiang13021/spark_stage_attempt_id.

Lead-authored-by: jiang13021 <[email protected]>
Co-authored-by: Fu Chen <[email protected]>
Co-authored-by: Shuang <[email protected]>
Signed-off-by: Shuang <[email protected]>
3 people committed Aug 30, 2024
1 parent 3ee672e commit 3853075
Showing 12 changed files with 160 additions and 29 deletions.
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle.celeborn;

import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.scheduler.DAGScheduler;

public class SparkCommonUtils {
  public static void validateAttemptConfig(SparkConf conf) throws IllegalArgumentException {
    int maxStageAttempts =
        conf.getInt(
            "spark.stage.maxConsecutiveAttempts",
            DAGScheduler.DEFAULT_MAX_CONSECUTIVE_STAGE_ATTEMPTS());
    // In Spark 2, the parameter is referred to as MAX_TASK_FAILURES, while in Spark 3, it has been
    // changed to TASK_MAX_FAILURES. The default value for both is consistently set to 4.
    int maxTaskAttempts = conf.getInt("spark.task.maxFailures", 4);
    if (maxStageAttempts >= (1 << 15) || maxTaskAttempts >= (1 << 16)) {
      // The map attemptId is a non-negative number constructed from
      // both stageAttemptNumber and taskAttemptNumber.
      // The high 16 bits of the map attemptId are used for the stageAttemptNumber,
      // and the low 16 bits are used for the taskAttemptNumber.
      // So spark.stage.maxConsecutiveAttempts should be less than 32768 (1 << 15)
      // and spark.task.maxFailures should be less than 65536 (1 << 16).
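      // For example, 32768 << 16 equals 1 << 31, which overflows into the sign bit of a
      // 32-bit int and would make the encoded attemptId negative.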
      throw new IllegalArgumentException(
          "The spark.stage.maxConsecutiveAttempts should be less than 32768 (currently "
              + maxStageAttempts
              + ") "
              + "and spark.task.maxFailures should be less than 65536 (currently "
              + maxTaskAttempts
              + ").");
    }
  }

  public static int getEncodedAttemptNumber(TaskContext context) {
    return (context.stageAttemptNumber() << 16) | context.attemptNumber();
  }
}
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -112,6 +113,7 @@ public HashBasedShuffleWriter(
this.mapId = mapId;
this.dep = handle.dependency();
this.shuffleId = shuffleId;
this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -146,7 +148,7 @@ public HashBasedShuffleWriter(
new DataPusher(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -279,7 +281,7 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw
shuffleClient.pushData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
partitionId,
buffer,
0,
@@ -333,7 +335,7 @@ private void close() throws IOException, InterruptedException {
// here we wait for all the in-flight batches sent by the dataPusher thread to return
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
shuffleClient.prepareForMergeData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);

// merge and push residual data to reduce network traffic
// NB: since the dataPusher thread has no in-flight data at this point,
@@ -345,7 +347,7 @@ private void close() throws IOException, InterruptedException {
shuffleClient.mergeData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
i,
sendBuffers[i],
0,
@@ -358,7 +360,7 @@ private void close() throws IOException, InterruptedException {
writeMetrics.incBytesWritten(bytesWritten);
}
}
shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);

updateMapStatus();

@@ -367,7 +369,7 @@ private void close() throws IOException, InterruptedException {
sendOffsets = null;

long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);

BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -404,7 +406,7 @@ public Option<MapStatus> stop(boolean success) {
}
}
} finally {
shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
}
@@ -62,6 +62,7 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final ShuffleWriteMetrics writeMetrics;
private final int shuffleId;
private final int mapId;
private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -102,6 +103,7 @@ public SortBasedShuffleWriter(
this.mapId = taskContext.partitionId();
this.dep = dep;
this.shuffleId = shuffleId;
this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
@@ -130,7 +132,7 @@ public SortBasedShuffleWriter(
taskContext,
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -285,7 +287,7 @@ private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throw
shuffleClient.pushData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
partitionId,
buffer,
0,
@@ -309,12 +311,12 @@ private void close() throws IOException {
pusher.close(true);
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);

shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);

updateMapStatus();

long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);
}

@@ -350,7 +352,7 @@ public Option<MapStatus> stop(boolean success) {
} catch (IOException e) {
return Option.apply(null);
} finally {
shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}

@@ -66,6 +66,7 @@ public class SparkShuffleManager implements ShuffleManager {
private ExecutorShuffleIdTracker shuffleIdTracker = new ExecutorShuffleIdTracker();

public SparkShuffleManager(SparkConf conf, boolean isDriver) {
SparkCommonUtils.validateAttemptConfig(conf);
this.conf = conf;
this.isDriver = isDriver;
this.celebornConf = SparkUtils.fromSparkConf(conf);
@@ -56,6 +56,7 @@ class CelebornShuffleReader[K, C](
handle.extension)

private val exceptionRef = new AtomicReference[IOException]
private val encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(context)

override def read(): Iterator[Product2[K, C]] = {

@@ -96,7 +97,7 @@ class CelebornShuffleReader[K, C](
val inputStream = shuffleClient.readPartition(
shuffleId,
partitionId,
context.attemptNumber(),
encodedAttemptId,
startMapIndex,
endMapIndex,
metricsCallback)
@@ -20,6 +20,7 @@ package org.apache.spark.shuffle.celeborn
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.internal.SQLConf
import org.junit
import org.junit.Assert
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

@@ -67,4 +68,33 @@ class SparkShuffleManagerSuite extends Logging {
    sc.stop()
  }

  @junit.Test
  def testWrongSparkConfMaxAttemptLimit(): Unit = {
    val conf = new SparkConf().setIfMissing("spark.master", "local")
      .setIfMissing(
        "spark.shuffle.manager",
        "org.apache.spark.shuffle.celeborn.SparkShuffleManager")
      .set(s"spark.${CelebornConf.MASTER_ENDPOINTS.key}", "localhost:9097")
      .set(s"spark.${CelebornConf.CLIENT_PUSH_REPLICATE_ENABLED.key}", "false")
      .set("spark.shuffle.service.enabled", "false")
      .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")

    // default conf, should succeed
    new SparkShuffleManager(conf, true)

    conf
      .set("spark.stage.maxConsecutiveAttempts", "32768")
      .set("spark.task.maxFailures", "10")
    try {
      new SparkShuffleManager(conf, true)
      Assert.fail()
    } catch {
      case e: IllegalArgumentException =>
        Assert.assertTrue(
          e.getMessage.contains("The spark.stage.maxConsecutiveAttempts should be less than 32768"))
      case _: Throwable =>
        Assert.fail()
    }
  }

}
@@ -17,7 +17,7 @@

package org.apache.spark.shuffle.celeborn

import org.apache.spark.{ShuffleDependency, SparkConf}
import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}
import org.apache.spark.sql.execution.UnsafeRowSerializer
import org.apache.spark.sql.execution.columnar.CelebornColumnarBatchSerializerInstance
@@ -45,14 +45,17 @@ class CelebornColumnarShuffleReaderSuite {

var shuffleClient: MockedStatic[ShuffleClient] = null
try {
val taskContext = Mockito.mock(classOf[TaskContext])
Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
Mockito.when(taskContext.attemptNumber).thenReturn(0)
shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
val shuffleReader = SparkUtils.createColumnarShuffleReader(
handle,
0,
10,
0,
10,
null,
taskContext,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
@@ -68,6 +71,9 @@ def columnarShuffleReaderNewSerializerInstance(): Unit = {
def columnarShuffleReaderNewSerializerInstance(): Unit = {
var shuffleClient: MockedStatic[ShuffleClient] = null
try {
val taskContext = Mockito.mock(classOf[TaskContext])
Mockito.when(taskContext.stageAttemptNumber).thenReturn(0)
Mockito.when(taskContext.attemptNumber).thenReturn(0)
shuffleClient = Mockito.mockStatic(classOf[ShuffleClient])
val shuffleReader = SparkUtils.createColumnarShuffleReader(
new CelebornShuffleHandle[Int, String, String](
@@ -83,7 +89,7 @@
10,
0,
10,
null,
taskContext,
new CelebornConf(),
null,
new ExecutorShuffleIdTracker())
@@ -68,6 +68,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
private final ShuffleWriteMetricsReporter writeMetrics;
private final int shuffleId;
private final int mapId;
private final int encodedAttemptId;
private final TaskContext taskContext;
private final ShuffleClient shuffleClient;
private final int numMappers;
@@ -112,6 +113,7 @@ public HashBasedShuffleWriter(
this.mapId = taskContext.partitionId();
this.dep = handle.dependency();
this.shuffleId = shuffleId;
this.encodedAttemptId = SparkCommonUtils.getEncodedAttemptNumber(taskContext);
SerializerInstance serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = metrics;
@@ -142,7 +144,7 @@ public HashBasedShuffleWriter(
new DataPusher(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
taskContext.taskAttemptId(),
numMappers,
numPartitions,
@@ -279,7 +281,7 @@ protected void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) thr
shuffleClient.pushData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
partitionId,
buffer,
0,
@@ -343,7 +345,7 @@ protected void mergeData(int partitionId, byte[] buffer, int offset, int length)
shuffleClient.mergeData(
shuffleId,
mapId,
taskContext.attemptNumber(),
encodedAttemptId,
partitionId,
buffer,
offset,
@@ -368,14 +370,14 @@ private void close() throws IOException, InterruptedException {
long pushMergedDataTime = System.nanoTime();
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
shuffleClient.prepareForMergeData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.prepareForMergeData(shuffleId, mapId, encodedAttemptId);
closeWrite();
shuffleClient.pushMergedData(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
writeMetrics.incWriteTime(System.nanoTime() - pushMergedDataTime);
updateRecordsWrittenMetrics();

long waitStartTime = System.nanoTime();
shuffleClient.mapperEnd(shuffleId, mapId, taskContext.attemptNumber(), numMappers);
shuffleClient.mapperEnd(shuffleId, mapId, encodedAttemptId, numMappers);
writeMetrics.incWriteTime(System.nanoTime() - waitStartTime);

BlockManagerId bmId = SparkEnv.get().blockManager().shuffleServerId();
@@ -408,7 +410,7 @@ public Option<MapStatus> stop(boolean success) {
}
}
} finally {
shuffleClient.cleanup(shuffleId, mapId, taskContext.attemptNumber());
shuffleClient.cleanup(shuffleId, mapId, encodedAttemptId);
}
}
