diff --git a/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java b/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java
index b2c0d8b4469..284ba3148ae 100644
--- a/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java
+++ b/client/src/test/java/org/apache/celeborn/client/write/DataPushQueueSuiteJ.java
@@ -20,10 +20,10 @@
 import java.io.File;
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.LongAdder;
 
 import org.junit.AfterClass;
@@ -45,7 +45,7 @@ public class DataPushQueueSuiteJ {
   private static File tempDir = null;
 
   private final int shuffleId = 0;
-  private final int numPartitions = 10;
+  private final int numPartitions = 1000000;
 
   @BeforeClass
   public static void beforeAll() {
@@ -63,7 +63,7 @@ public static void afterAll() {
 
   @Test
   public void testDataPushQueue() throws Exception {
-    final int numWorker = 3;
+    final int numWorker = 30;
     List<List<Integer>> workerData = new ArrayList<>();
     for (int i = 0; i < numWorker; i++) {
       workerData.add(new ArrayList<>());
@@ -76,7 +76,7 @@ public void testDataPushQueue() throws Exception {
       tarWorkerData.add(new ArrayList<>());
     }
 
-    Map<Integer, Integer> partitionBatchIdMap = new HashMap<>();
+    Map<Integer, Integer> partitionBatchIdMap = new ConcurrentHashMap<>();
 
     CelebornConf conf = new CelebornConf();
     conf.set(CelebornConf.CLIENT_PUSH_MAX_REQS_IN_FLIGHT_PERWORKER().key(), "2");
@@ -85,7 +85,6 @@ public void testDataPushQueue() throws Exception {
     int mapId = 0;
     int attemptId = 0;
     int numMappers = 10;
-    int numPartitions = 10;
     final File tempFile = new File(tempDir, UUID.randomUUID().toString());
     DummyShuffleClient client = new DummyShuffleClient(conf, tempFile);
     client.initReducePartitionMap(shuffleId, numPartitions, numWorker);
@@ -125,10 +124,10 @@ protected void pushData(PushTask task) throws IOException {
 
     for (int i = 0; i < numPartitions; i++) {
       byte[] b = intToBytes(workerData.get(i % numWorker).get(i / numWorker));
-      dataPusher.addTask(i, b, b.length);
       int batchId = pushState.nextBatchId();
       pushState.addBatch(batchId, reducePartitionMap.get(i).hostAndPushPort());
       partitionBatchIdMap.put(i, batchId);
+      dataPusher.addTask(i, b, b.length);
     }
 
     dataPusher.waitOnTermination();