diff --git a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java index bc72af8a634..f8fc35a61c1 100644 --- a/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java +++ b/client-flink/common/src/main/java/org/apache/celeborn/plugin/flink/readclient/CelebornBufferStream.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Supplier; @@ -44,7 +45,7 @@ public class CelebornBufferStream { private int subIndexStart; private int subIndexEnd; private TransportClient client; - private int currentLocationIndex = 0; + private AtomicInteger currentLocationIndex = new AtomicInteger(0); private long streamId = 0; private FlinkShuffleClientImpl mapShuffleClient; private boolean isClosed; @@ -153,7 +154,7 @@ public void moveToNextPartitionIfPossible(long endedStreamId) { endedStreamId, currentLocationIndex, streamId); - if (currentLocationIndex > 0) { + if (currentLocationIndex.get() > 0) { if (endedStreamId == streamId) { logger.debug("Get end streamId {}", endedStreamId); cleanStream(endedStreamId); @@ -165,10 +166,10 @@ public void moveToNextPartitionIfPossible(long endedStreamId) { return; } } - if (currentLocationIndex < locations.length) { + if (currentLocationIndex.get() < locations.length) { try { openStreamInternal(); - currentLocationIndex++; + currentLocationIndex.incrementAndGet(); } catch (Exception e) { logger.warn("Failed to open stream and report to flink framework. ", e); messageConsumer.accept(new TransportableError(0L, e)); @@ -179,9 +180,9 @@ public void moveToNextPartitionIfPossible(long endedStreamId) { private void openStreamInternal() throws IOException, InterruptedException { this.client = clientFactory.createClientWithRetry( - locations[currentLocationIndex].getHost(), - locations[currentLocationIndex].getFetchPort()); - String fileName = locations[currentLocationIndex].getFileName(); + locations[currentLocationIndex.get()].getHost(), + locations[currentLocationIndex.get()].getFetchPort()); + String fileName = locations[currentLocationIndex.get()].getFileName(); TransportMessage openStream = new TransportMessage( MessageType.OPEN_STREAM,