[CELEBORN-1490][CIP-6] Introduce tier consumer for hybrid shuffle #2786

Status: Open. Wants to merge 1 commit into base: main.
@@ -66,8 +66,6 @@ private void processMessageInternal(long streamId, RequestMessage msg) {
     } else {
       if (msg != null && msg instanceof ReadData) {
         ((ReadData) msg).getFlinkBuffer().release();
-      } else if (msg != null && msg instanceof SubPartitionReadData) {
-        ((SubPartitionReadData) msg).getFlinkBuffer().release();
       }
 
       logger.warn("Unexpected streamId received: {}", streamId);
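Note that the two removed lines are redundant rather than a behavior change: because SubPartitionReadData extends ReadData after this PR (see the SubPartitionReadData hunks below), the first instanceof ReadData branch matches both message types and releases the Flink buffer in either case. A stand-in sketch of the dispatch, in plain Java rather than Celeborn source:

    class ReadData {
      void releaseBuffer() {
        System.out.println("buffer released");
      }
    }

    class SubPartitionReadData extends ReadData {}

    public class InstanceofSketch {
      public static void main(String[] args) {
        Object msg = new SubPartitionReadData();
        // A subclass instance satisfies an instanceof check against its parent
        // type, so the single ReadData branch covers SubPartitionReadData too.
        if (msg instanceof ReadData) {
          ((ReadData) msg).releaseBuffer();
        }
      }
    }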
@@ -23,9 +23,9 @@
 
 import org.apache.celeborn.common.network.protocol.RequestMessage;
 
-public final class ReadData extends RequestMessage {
-  private final long streamId;
-  private ByteBuf flinkBuffer;
+public class ReadData extends RequestMessage {
+  protected final long streamId;
+  protected ByteBuf flinkBuffer;
 
   @Override
   public boolean needCopyOut() {
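Widening streamId and flinkBuffer from private to protected, and dropping final from the class declaration, is what allows SubPartitionReadData in the next file to inherit both fields instead of duplicating them.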
@@ -20,46 +20,30 @@
 
 import java.util.Objects;
 
-import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
-
-import org.apache.celeborn.common.network.protocol.ReadData;
-import org.apache.celeborn.common.network.protocol.RequestMessage;
-
 /**
  * Compared with {@link ReadData}, this class has an additional field of subpartitionId. This
  * class is added to keep backward compatibility.
  */
-public class SubPartitionReadData extends RequestMessage {
-  private final long streamId;
+public class SubPartitionReadData extends ReadData {
   private final int subPartitionId;
-  private ByteBuf flinkBuffer;
-
-  @Override
-  public boolean needCopyOut() {
-    return true;
-  }
 
   public SubPartitionReadData(long streamId, int subPartitionId) {
+    super(streamId);
     this.subPartitionId = subPartitionId;
-    this.streamId = streamId;
   }
 
   @Override
   public int encodedLength() {
-    return 8 + 4;
+    return super.encodedLength() + 4;
   }
 
   // This method will not be called because ReadData won't be created at the flink client.
   @Override
   public void encode(io.netty.buffer.ByteBuf buf) {
-    buf.writeLong(streamId);
+    super.encode(buf);
     buf.writeInt(subPartitionId);
   }
 
-  public long getStreamId() {
-    return streamId;
-  }
-
   public int getSubPartitionId() {
     return subPartitionId;
   }
@@ -74,8 +58,8 @@ public boolean equals(Object o) {
     if (this == o) return true;
     if (o == null || getClass() != o.getClass()) return false;
     SubPartitionReadData readData = (SubPartitionReadData) o;
-    return streamId == readData.streamId
-        && subPartitionId == readData.subPartitionId
+    return streamId == readData.getStreamId()
+        && subPartitionId == readData.getSubPartitionId()
         && Objects.equals(flinkBuffer, readData.flinkBuffer);
   }
 
@@ -95,12 +79,4 @@ public String toString() {
         + flinkBuffer
         + '}';
   }
-
-  public ByteBuf getFlinkBuffer() {
-    return flinkBuffer;
-  }
-
-  public void setFlinkBuffer(ByteBuf flinkBuffer) {
-    this.flinkBuffer = flinkBuffer;
-  }
 }
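The subclass now derives its wire format from the parent: encodedLength() adds 4 bytes for subPartitionId on top of the parent's length, and encode() writes the parent's fields before appending its own, so the streamId prefix stays byte-compatible with the old layout. A stand-in sketch of this delegation pattern, in plain Java with java.nio rather than Celeborn's Netty types:

    import java.nio.ByteBuffer;

    class Base {
      protected final long streamId;

      Base(long streamId) {
        this.streamId = streamId;
      }

      int encodedLength() {
        return 8; // streamId
      }

      void encode(ByteBuffer buf) {
        buf.putLong(streamId);
      }
    }

    class Sub extends Base {
      private final int subPartitionId;

      Sub(long streamId, int subPartitionId) {
        super(streamId);
        this.subPartitionId = subPartitionId;
      }

      @Override
      int encodedLength() {
        return super.encodedLength() + 4; // + subPartitionId
      }

      @Override
      void encode(ByteBuffer buf) {
        super.encode(buf); // parent fields first, keeping the prefix compatible
        buf.putInt(subPartitionId);
      }
    }

    public class EncodeSketch {
      public static void main(String[] args) {
        Sub msg = new Sub(42L, 7);
        ByteBuffer buf = ByteBuffer.allocate(msg.encodedLength());
        msg.encode(buf);
        System.out.println(buf.position()); // 12 = 8 (streamId) + 4 (subPartitionId)
      }
    }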
@@ -20,9 +20,12 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 
+import javax.annotation.Nullable;
+
 import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@@ -36,6 +39,7 @@
 import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PartitionLocation;
 import org.apache.celeborn.common.protocol.PbBufferStreamEnd;
+import org.apache.celeborn.common.protocol.PbNotifyRequiredSegment;
 import org.apache.celeborn.common.protocol.PbOpenStream;
 import org.apache.celeborn.common.protocol.PbReadAddCredit;
 import org.apache.celeborn.common.protocol.PbStreamHandler;
@@ -108,10 +112,38 @@ public void onFailure(Throwable e) {
         });
   }
 
+  public void notifyRequiredSegment(PbNotifyRequiredSegment pbNotifyRequiredSegment) {
+    this.client.sendRpc(
+        new TransportMessage(
+                MessageType.NOTIFY_REQUIRED_SEGMENT, pbNotifyRequiredSegment.toByteArray())
+            .toByteBuffer(),
+        new RpcResponseCallback() {
+
+          @Override
+          public void onSuccess(ByteBuffer response) {
+            // Sending PbNotifyRequiredSegment does not expect a response.
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            logger.error(
+                "Send PbNotifyRequiredSegment to {} failed, streamId {}, detail {}",
+                NettyUtils.getRemoteAddress(client.getChannel()),
+                streamId,
+                e.getCause());
+            messageConsumer.accept(new TransportableError(streamId, e));
+          }
+        });
+  }
+
   public static CelebornBufferStream empty() {
     return EMPTY_CELEBORN_BUFFER_STREAM;
   }
 
+  public static boolean isEmptyStream(CelebornBufferStream stream) {
+    return stream == null || stream == EMPTY_CELEBORN_BUFFER_STREAM;
+  }
+
   public long getStreamId() {
     return streamId;
   }
@@ -165,6 +197,11 @@ public void close() {
   }
 
   public void moveToNextPartitionIfPossible(long endedStreamId) {
+    moveToNextPartitionIfPossible(endedStreamId, null);
+  }
+
+  public void moveToNextPartitionIfPossible(
+      long endedStreamId, @Nullable BiConsumer<Long, Integer> requiredSegmentIdConsumer) {
     logger.debug(
         "MoveToNextPartitionIfPossible in this:{}, endedStreamId: {}, currentLocationIndex: {}, currentStreamId:{}, locationsLength:{}",
         this,
@@ -176,9 +213,10 @@ public void moveToNextPartitionIfPossible(long endedStreamId) {
       logger.debug("Get end streamId {}", endedStreamId);
       cleanStream(endedStreamId);
     }
+
     if (currentLocationIndex.get() < locations.length) {
       try {
-        openStreamInternal();
+        openStreamInternal(requiredSegmentIdConsumer);
         logger.debug(
             "MoveToNextPartitionIfPossible after openStream this:{}, endedStreamId: {}, currentLocationIndex: {}, currentStreamId:{}, locationsLength:{}",
             this,
@@ -193,7 +231,12 @@ public void moveToNextPartitionIfPossible(long endedStreamId) {
       }
     }
   }
 
-  private void openStreamInternal() throws IOException, InterruptedException {
+  /**
+   * Open the stream. Note that if requiredSegmentIdConsumer is not null, it will be invoked for
+   * every subpartition once the stream is opened successfully.
+   */
+  private void openStreamInternal(@Nullable BiConsumer<Long, Integer> requiredSegmentIdConsumer)
+      throws IOException, InterruptedException {
     this.client =
         clientFactory.createClientWithRetry(
             locations[currentLocationIndex.get()].getHost(),
@@ -208,6 +251,7 @@ private void openStreamInternal() throws IOException, InterruptedException {
                 .setStartIndex(subIndexStart)
                 .setEndIndex(subIndexEnd)
                 .setInitialCredit(initialCredit)
+                .setRequireSubpartitionId(true)
                 .build()
                 .toByteArray());
     client.sendRpc(
@@ -228,6 +272,13 @@ public void onSuccess(ByteBuffer response) {
                   .getReadClientHandler()
                   .registerHandler(streamId, messageConsumer, client);
               isOpenSuccess = true;
+              if (requiredSegmentIdConsumer != null) {
+                for (int subPartitionId = subIndexStart;
+                    subPartitionId <= subIndexEnd;
+                    subPartitionId++) {
+                  requiredSegmentIdConsumer.accept(streamId, subPartitionId);
+                }
+              }
               logger.debug(
                   "open stream success from remote:{}, stream id:{}, fileName: {}",
                   client.getSocketAddress(),
@@ -267,4 +318,8 @@ public void onFailure(Throwable e) {
   public TransportClient getClient() {
     return client;
   }
+
+  public boolean isOpened() {
+    return isOpenSuccess;
+  }
 }
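A hypothetical call site for the new overload. Only methods visible in this diff are used; the package path and the stream variable are assumptions for illustration, and a real caller would send PbNotifyRequiredSegment via notifyRequiredSegment instead of printing:

    import java.util.function.BiConsumer;

    import org.apache.celeborn.plugin.flink.readclient.CelebornBufferStream;

    public class RequiredSegmentSketch {
      static void moveToNext(CelebornBufferStream stream, long endedStreamId) {
        // Invoked once per subpartition in [subIndexStart, subIndexEnd] after
        // PbOpenStream succeeds (see openStreamInternal above).
        BiConsumer<Long, Integer> requiredSegmentIdConsumer =
            (streamId, subPartitionId) ->
                System.out.printf(
                    "announce required segment: stream %d, subpartition %d%n",
                    streamId, subPartitionId);
        stream.moveToNextPartitionIfPossible(endedStreamId, requiredSegmentIdConsumer);
      }
    }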
@@ -181,15 +181,11 @@ public CelebornBufferStream readBufferedPartition(
             shuffleId,
             partitionId,
             isSegmentGranularityVisible);
-        if (isSegmentGranularityVisible) {
-          // When the downstream reduce tasks start early than upstream map tasks, the shuffle
-          // partition locations may be found empty, should retry until the upstream task started
-          return CelebornBufferStream.empty();
-        } else {
-          throw new PartitionUnRetryAbleException(
-              String.format(
-                  "Shuffle data lost for shuffle %d partition %d.", shuffleId, partitionId));
-        }
+        // TODO: in segment-granularity-visible scenarios, when the downstream reduce tasks start
+        // earlier than upstream map tasks, the shuffle partition locations may be found empty;
+        // should retry until the upstream task has started
+        throw new PartitionUnRetryAbleException(
+            String.format("Shuffle data lost for shuffle %d partition %d.", shuffleId, partitionId));
     } else {
       Arrays.sort(partitionLocations, Comparator.comparingInt(PartitionLocation::getEpoch));
       logger.debug(
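Behavior change worth noting: empty partition locations now always raise PartitionUnRetryAbleException, where the old segment-granularity-visible branch returned CelebornBufferStream.empty() so the caller could retry; per the TODO, that retry path is deferred. A hypothetical caller sketch under that reading; the Callable stands in for readBufferedPartition, whose full signature is not shown in this diff, and the package path is an assumption:

    import java.util.concurrent.Callable;

    import org.apache.celeborn.plugin.flink.readclient.CelebornBufferStream;

    public class ReadRetrySketch {
      static CelebornBufferStream openOrEmpty(Callable<CelebornBufferStream> openCall) {
        try {
          return openCall.call();
        } catch (Exception e) {
          // PartitionUnRetryAbleException lands here while upstream tasks have
          // not yet registered locations; callers may retry later.
          return CelebornBufferStream.empty();
        }
      }

      static boolean usable(CelebornBufferStream stream) {
        // isEmptyStream is the helper added in the CelebornBufferStream hunk above.
        return !CelebornBufferStream.isEmptyStream(stream);
      }
    }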
@@ -21,6 +21,8 @@
 
 import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 import java.util.Random;
 import java.util.concurrent.ConcurrentHashMap;
@@ -31,19 +33,40 @@
 import io.netty.channel.ChannelHandlerContext;
 import org.junit.Assert;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 import org.mockito.Mockito;
 
 import org.apache.celeborn.common.network.buffer.NioManagedBuffer;
 import org.apache.celeborn.common.network.protocol.Message;
 import org.apache.celeborn.common.network.protocol.ReadData;
 import org.apache.celeborn.common.network.protocol.RpcRequest;
+import org.apache.celeborn.common.network.protocol.SubPartitionReadData;
 import org.apache.celeborn.common.network.protocol.TransportMessage;
 import org.apache.celeborn.common.protocol.MessageType;
 import org.apache.celeborn.common.protocol.PbBacklogAnnouncement;
 import org.apache.celeborn.common.util.JavaUtils;
 
+@RunWith(Parameterized.class)
 public class TransportFrameDecoderWithBufferSupplierSuiteJ {
+
+  enum TestReadDataType {
+    READ_DATA,
+    SUBPARTITION_READ_DATA,
+  }
+
+  private TestReadDataType testReadDataType;
+
+  public TransportFrameDecoderWithBufferSupplierSuiteJ(TestReadDataType testReadDataType) {
+    this.testReadDataType = testReadDataType;
+  }
+
+  @Parameterized.Parameters
+  public static Collection prepareData() {
+    Object[][] object = {{TestReadDataType.READ_DATA}, {TestReadDataType.SUBPARTITION_READ_DATA}};
+    return Arrays.asList(object);
+  }
+
   @Test
   public void testDropUnusedBytes() throws IOException {
     ConcurrentHashMap<Long, Supplier<org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf>>
@@ -64,11 +87,11 @@ public void testDropUnusedBytes() throws IOException {
     ChannelHandlerContext context = Mockito.mock(ChannelHandlerContext.class);
 
     RpcRequest announcement = createBacklogAnnouncement(0, 0);
-    ReadData unUsedReadData = new ReadData(1, generateData(1024));
-    ReadData readData = new ReadData(2, generateData(1024));
+    ReadData unUsedReadData = generateReadDataMessage(1, 0, generateData(1024));
+    ReadData readData = generateReadDataMessage(2, 0, generateData(1024));
     RpcRequest announcement1 = createBacklogAnnouncement(0, 0);
-    ReadData unUsedReadData1 = new ReadData(1, generateData(1024));
-    ReadData readData1 = new ReadData(2, generateData(8));
+    ReadData unUsedReadData1 = generateReadDataMessage(1, 0, generateData(1024));
+    ReadData readData1 = generateReadDataMessage(2, 0, generateData(8));
 
     ByteBuf buffer = Unpooled.buffer(5000);
     encodeMessage(announcement, buffer);
@@ -145,4 +168,12 @@ public ByteBuf generateData(int size) {
 
     return data;
   }
+
+  private ReadData generateReadDataMessage(long streamId, int subPartitionId, ByteBuf buf) {
+    if (testReadDataType == TestReadDataType.READ_DATA) {
+      return new ReadData(streamId, buf);
+    } else {
+      return new SubPartitionReadData(streamId, subPartitionId, buf);
+    }
+  }
 }
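With @RunWith(Parameterized.class), JUnit instantiates this suite once per element returned by prepareData(), so every test in the class now runs twice: once framing plain ReadData messages and once framing SubPartitionReadData messages through generateReadDataMessage.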