diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 201c174e434b2..9cff602cca537 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -27,7 +27,11 @@ import java.util.LinkedList; import java.util.Map; import java.util.Optional; +import java.util.concurrent.BlockingQueue; import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; import java.util.zip.Checksum; import javax.annotation.Nullable; @@ -174,30 +178,60 @@ public void write(Iterator> records) throws Exception { // included in the shuffle write time. writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - final long starttttt = System.currentTimeMillis(); if (useMultiThreadedShuffle) { - final LinkedList> writeFutures = new LinkedList<>(); - try { - while (records.hasNext()) { - final Product2 record = records.next(); - final K key = record._1(); - Map.Entry entry = - partitionSlotWriters.get(partitioner.getPartition(key)); - final int slotNum = entry.getKey(); - final DiskBlockObjectWriter writer = entry.getValue(); - - writeFutures.add(SortShuffleManager.queueWriteTask(slotNum, () -> { - writer.write(key, record._2()); + long inflight = 0; + final Map>> + partitionRecords = new HashMap<>(numPartitions); + final BlockingQueue finQueue = new LinkedBlockingQueue<>(); + final BlockingQueue errQueue = new LinkedBlockingQueue<>(); + while (records.hasNext()) { + inflight++; + final Product2 record = records.next(); + final K key = record._1(); + final int partition = partitioner.getPartition(key); + partitionRecords.putIfAbsent(partition, new LinkedList<>()); + LinkedList> recordBatch = partitionRecords.get(partition); + if (recordBatch.size() > 500) { + partitionRecords.remove(partition); + final LinkedList> batch = recordBatch; + SortShuffleManager.queueWriteTask(partitionSlotWriters.get(partition).getKey(), + () -> { + try { + final DiskBlockObjectWriter writer = partitionSlotWriters.get(partition).getValue(); + batch.forEach(r -> writer.write(r._1(), r._2())); + finQueue.put(batch.size()); + } catch (Throwable t) { + errQueue.put(t); + } return null; - })); + }); } - } finally { - try { - while (!writeFutures.isEmpty()) { - writeFutures.remove().get(); - } - } finally { - writeFutures.forEach(f -> f.cancel(true)); + } + + partitionRecords.forEach((partition, batch) -> + SortShuffleManager.queueWriteTask(partitionSlotWriters.get(partition).getKey(), + () -> { + try { + final DiskBlockObjectWriter writer = + partitionSlotWriters.get(partition).getValue(); + batch.forEach(r -> writer.write(r._1(), r._2())); + finQueue.put(batch.size()); + } catch (Throwable t) { + errQueue.put(t); + } + return null; + })); + + Integer size; + Throwable thr; + while (inflight > 0) { + thr = errQueue.poll(50, TimeUnit.MILLISECONDS); + if (thr != null) { + throw new IOException(thr); + } + size = finQueue.poll(50, TimeUnit.MILLISECONDS); + if(size != null) { + inflight -= size; } } } else { @@ -208,8 +242,6 @@ public void write(Iterator> records) throws Exception { } } - final long timeSpanned = System.currentTimeMillis() - starttttt; - for (int i = 0; i < numPartitions; i++) { try (DiskBlockObjectWriter writer = partitionSlotWriters.get(i).getValue()) { partitionWriterSegments[i] = writer.commitAndGet();