Skip to content

Commit

Permalink
[CELEBORN-1490][CIP-6] Introduce tier consumer for hybrid shuffle
Browse files Browse the repository at this point in the history
Co-authored-by: Xu Huang <[email protected]>
  • Loading branch information
reswqa and codenohup committed Oct 8, 2024
1 parent 3b2e70c commit ead982f
Show file tree
Hide file tree
Showing 8 changed files with 1,131 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.util.FrameDecoder;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;

public class TransportFrameDecoderWithBufferSupplier extends ChannelInboundHandlerAdapter
implements FrameDecoder {
Expand Down Expand Up @@ -123,8 +124,15 @@ private io.netty.buffer.ByteBuf decodeBodyCopyOut(
return buf;
}

ReadData readData = (ReadData) curMsg;
long streamId = readData.getStreamId();
long streamId;
if (curMsg instanceof ReadData) {
ReadData readData = (ReadData) curMsg;
streamId = readData.getStreamId();
} else {
SubPartitionReadData readData = (SubPartitionReadData) curMsg;
streamId = readData.getStreamId();
}

if (externalBuf == null) {
Supplier<ByteBuf> supplier = bufferSuppliers.get(streamId);
if (supplier == null) {
Expand All @@ -140,7 +148,11 @@ private io.netty.buffer.ByteBuf decodeBodyCopyOut(

copyByteBuf(buf, externalBuf, bodySize);
if (externalBuf.readableBytes() == bodySize) {
((ReadData) curMsg).setFlinkBuffer(externalBuf);
if (curMsg instanceof ReadData) {
((ReadData) curMsg).setFlinkBuffer(externalBuf);
} else {
((SubPartitionReadData) curMsg).setFlinkBuffer(externalBuf);
}
ctx.fireChannelRead(curMsg);
clear();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;

import javax.annotation.Nullable;

import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -36,6 +39,7 @@
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PartitionLocation;
import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
import org.apache.celeborn.common.protocol.PbNotifyRequiredSegment;
import org.apache.celeborn.common.protocol.PbOpenStream;
import org.apache.celeborn.common.protocol.PbReadAddCredit;
import org.apache.celeborn.common.protocol.PbStreamHandler;
Expand Down Expand Up @@ -108,10 +112,27 @@ public void onFailure(Throwable e) {
});
}

/**
 * Sends the given credit message to the worker through the current client.
 *
 * <p>The protobuf payload is wrapped in a {@code READ_ADD_CREDIT} transport message and handed to
 * {@code sendRpc} with the payload as its only argument; no callback is passed at this call site.
 *
 * @param pbReadAddCredit the credit message to serialize and send
 */
public void addCreditWithoutResponse(PbReadAddCredit pbReadAddCredit) {
  TransportMessage creditMessage =
      new TransportMessage(MessageType.READ_ADD_CREDIT, pbReadAddCredit.toByteArray());
  this.client.sendRpc(creditMessage.toByteBuffer());
}

/**
 * Sends the given required-segment notification to the worker through the current client.
 *
 * <p>The protobuf payload is wrapped in a {@code NOTIFY_REQUIRED_SEGMENT} transport message and
 * handed to {@code sendRpc} with the payload as its only argument; no callback is passed at this
 * call site.
 *
 * @param pbNotifyRequiredSegment the notification to serialize and send
 */
public void notifyRequiredSegment(PbNotifyRequiredSegment pbNotifyRequiredSegment) {
  TransportMessage segmentMessage =
      new TransportMessage(
          MessageType.NOTIFY_REQUIRED_SEGMENT, pbNotifyRequiredSegment.toByteArray());
  this.client.sendRpc(segmentMessage.toByteBuffer());
}

/**
 * Returns the shared placeholder instance that stands for "no stream".
 *
 * <p>The same {@code EMPTY_CELEBORN_BUFFER_STREAM} reference is returned on every call, so callers
 * can detect it by identity (see {@link #isEmptyStream(CelebornBufferStream)}).
 */
public static CelebornBufferStream empty() {
return EMPTY_CELEBORN_BUFFER_STREAM;
}

/**
 * Tells whether the given stream carries no real stream.
 *
 * @param stream the stream to inspect; may be {@code null}
 * @return {@code true} if {@code stream} is {@code null} or is the shared empty placeholder
 *     instance (compared by reference), {@code false} otherwise
 */
public static boolean isEmptyStream(CelebornBufferStream stream) {
  if (stream == null) {
    return true;
  }
  return stream == EMPTY_CELEBORN_BUFFER_STREAM;
}

/** Returns the current value of this stream's {@code streamId} field. */
public long getStreamId() {
  return this.streamId;
}
Expand Down Expand Up @@ -165,6 +186,11 @@ public void close() {
}

/**
 * Convenience overload that moves to the next partition location without a required-segment
 * consumer.
 *
 * <p>Delegates to the two-argument overload, passing {@code null} as the consumer.
 *
 * @param endedStreamId id of the stream that has ended
 */
public void moveToNextPartitionIfPossible(long endedStreamId) {
  final BiConsumer<Long, Integer> noRequiredSegmentConsumer = null;
  moveToNextPartitionIfPossible(endedStreamId, noRequiredSegmentConsumer);
}

public void moveToNextPartitionIfPossible(
long endedStreamId, @Nullable BiConsumer<Long, Integer> requiredSegmentIdConsumer) {
logger.debug(
"MoveToNextPartitionIfPossible in this:{}, endedStreamId: {}, currentLocationIndex: {}, currentSteamId:{}, locationsLength:{}",
this,
Expand All @@ -176,9 +202,10 @@ public void moveToNextPartitionIfPossible(long endedStreamId) {
logger.debug("Get end streamId {}", endedStreamId);
cleanStream(endedStreamId);
}

if (currentLocationIndex.get() < locations.length) {
try {
openStreamInternal();
openStreamInternal(requiredSegmentIdConsumer);
logger.debug(
"MoveToNextPartitionIfPossible after openStream this:{}, endedStreamId: {}, currentLocationIndex: {}, currentSteamId:{}, locationsLength:{}",
this,
Expand All @@ -193,7 +220,12 @@ public void moveToNextPartitionIfPossible(long endedStreamId) {
}
}

private void openStreamInternal() throws IOException, InterruptedException {
/**
* Opens the stream. If requiredSegmentIdConsumer is not null, it is invoked with the stream id
* and each subpartition id from subIndexStart to subIndexEnd (inclusive) once the stream has
* been opened successfully.
*/
private void openStreamInternal(@Nullable BiConsumer<Long, Integer> requiredSegmentIdConsumer)
throws IOException, InterruptedException {
this.client =
clientFactory.createClientWithRetry(
locations[currentLocationIndex.get()].getHost(),
Expand All @@ -208,6 +240,7 @@ private void openStreamInternal() throws IOException, InterruptedException {
.setStartIndex(subIndexStart)
.setEndIndex(subIndexEnd)
.setInitialCredit(initialCredit)
.setRequireSubpartitionId(true)
.build()
.toByteArray());
client.sendRpc(
Expand All @@ -228,6 +261,13 @@ public void onSuccess(ByteBuffer response) {
.getReadClientHandler()
.registerHandler(streamId, messageConsumer, client);
isOpenSuccess = true;
if (requiredSegmentIdConsumer != null) {
for (int subPartitionId = subIndexStart;
subPartitionId <= subIndexEnd;
subPartitionId++) {
requiredSegmentIdConsumer.accept(streamId, subPartitionId);
}
}
logger.debug(
"open stream success from remote:{}, stream id:{}, fileName: {}",
client.getSocketAddress(),
Expand Down Expand Up @@ -267,4 +307,8 @@ public void onFailure(Throwable e) {
/** Returns the transport client most recently assigned to this stream's {@code client} field. */
public TransportClient getClient() {
  return this.client;
}

/**
 * Reports whether the stream has been opened.
 *
 * @return the current value of the {@code isOpenSuccess} flag
 */
public boolean isOpened() {
  return this.isOpenSuccess;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,11 @@ public CelebornBufferStream readBufferedPartition(
shuffleId,
partitionId,
isSegmentGranularityVisible);
if (isSegmentGranularityVisible) {
// When the downstream reduce tasks start early than upstream map tasks, the shuffle
// partition locations may be found empty, should retry until the upstream task started
return CelebornBufferStream.empty();
} else {
throw new PartitionUnRetryAbleException(
String.format(
"Shuffle data lost for shuffle %d partition %d.", shuffleId, partitionId));
}
// TODO: in segment-granularity-visible scenarios, when the downstream reduce tasks start earlier
// than the upstream map tasks, the shuffle partition locations may be found empty; we should
// retry until the upstream tasks have started.
throw new PartitionUnRetryAbleException(
String.format("Shuffle data lost for shuffle %d partition %d.", shuffleId, partitionId));
} else {
Arrays.sort(partitionLocations, Comparator.comparingInt(PartitionLocation::getEpoch));
logger.debug(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -31,19 +33,41 @@
import io.netty.channel.ChannelHandlerContext;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.Mockito;

import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
import org.apache.celeborn.common.network.protocol.Message;
import org.apache.celeborn.common.network.protocol.ReadData;
import org.apache.celeborn.common.network.protocol.RequestMessage;
import org.apache.celeborn.common.network.protocol.RpcRequest;
import org.apache.celeborn.common.network.protocol.SubPartitionReadData;
import org.apache.celeborn.common.network.protocol.TransportMessage;
import org.apache.celeborn.common.protocol.MessageType;
import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
import org.apache.celeborn.common.util.JavaUtils;

@RunWith(Parameterized.class)
public class TransportFrameDecoderWithBufferSupplierSuiteJ {

/** Selects which read-data message class a parameterized run of this suite exercises. */
enum TestReadDataType {
  /** Build plain {@code ReadData} messages. */
  READ_DATA,

  /** Build {@code SubPartitionReadData} messages. */
  SUBPARTITION_READ_DATA
}

private TestReadDataType testReadDataType;

/**
 * Creates one suite instance for a single parameter value.
 *
 * @param testReadDataType the message type this instance builds in {@code
 *     generateReadDataMessage}
 */
public TransportFrameDecoderWithBufferSupplierSuiteJ(TestReadDataType testReadDataType) {
this.testReadDataType = testReadDataType;
}

/**
 * Supplies the parameter rows for this parameterized suite.
 *
 * <p>Each row holds a single {@link TestReadDataType}, matching the one-argument constructor, so
 * every test runs once with {@code READ_DATA} and once with {@code SUBPARTITION_READ_DATA}.
 *
 * @return a fixed-size list with one {@code Object[]} row per read-data type
 */
@Parameterized.Parameters
public static Collection<Object[]> prepareData() {
  // Typed as Collection<Object[]> rather than a raw Collection to avoid unchecked warnings.
  Object[][] parameterRows = {
    {TestReadDataType.READ_DATA}, {TestReadDataType.SUBPARTITION_READ_DATA}
  };
  return Arrays.asList(parameterRows);
}

@Test
public void testDropUnusedBytes() throws IOException {
ConcurrentHashMap<Long, Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
Expand All @@ -64,11 +88,11 @@ public void testDropUnusedBytes() throws IOException {
ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);

RpcRequest announcement = createBacklogAnnouncement(0, 0);
ReadData unUsedReadData = new ReadData(1, generateData(1024));
ReadData readData = new ReadData(2, generateData(1024));
RequestMessage unUsedReadData = generateReadDataMessage(1, 0, generateData(1024));
RequestMessage readData = generateReadDataMessage(2, 0, generateData(1024));
RpcRequest announcement1 = createBacklogAnnouncement(0, 0);
ReadData unUsedReadData1 = new ReadData(1, generateData(1024));
ReadData readData1 = new ReadData(2, generateData(8));
RequestMessage unUsedReadData1 = generateReadDataMessage(1, 0, generateData(1024));
RequestMessage readData1 = generateReadDataMessage(2, 0, generateData(8));

ByteBuf buffer = Unpooled.buffer(5000);
encodeMessage(announcement, buffer);
Expand Down Expand Up @@ -145,4 +169,12 @@ public ByteBuf generateData(int size) {

return data;
}

/**
 * Builds the read-data message matching this run's {@code testReadDataType}.
 *
 * @param streamId stream id to put in the message
 * @param subPartitionId subpartition id; only used when a {@code SubPartitionReadData} is built
 * @param buf message body
 * @return a {@code ReadData} when the suite runs with {@code READ_DATA}, otherwise a {@code
 *     SubPartitionReadData}
 */
private RequestMessage generateReadDataMessage(long streamId, int subPartitionId, ByteBuf buf) {
  if (testReadDataType != TestReadDataType.READ_DATA) {
    return new SubPartitionReadData(streamId, subPartitionId, buf);
  }
  return new ReadData(streamId, buf);
}
}
Loading

0 comments on commit ead982f

Please sign in to comment.