Skip to content

Commit

Permalink
[CELEBORN-1490][CIP-6] Extends message to support hybrid shuffle
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This is the first PR to support Hybrid Shuffle.

Extends message to support hybrid shuffle.

### Why are the changes needed?
hybrid shuffle is a tiered storage architecture, which introduces the concept of `segment`. One segment's data selects a tier to send. Data is split into segments and sent to multiple tiers.

This PR introduces segment-related message. In addition, hybrid shuffle needs to distinguish which subpartition it comes from when consuming data, so we need to extend the `SubpartitionId` field to `ReadData` (new class introduced for compatibility).

### Does this PR introduce _any_ user-facing change?
no.

### How was this patch tested?
no need.

Closes apache#2714 from reswqa/cip6-1-extend-message.

Authored-by: Weijie Guo <[email protected]>
Signed-off-by: Shuang <[email protected]>
  • Loading branch information
reswqa authored and RexXiong committed Aug 30, 2024
1 parent 3853075 commit 5d61458
Show file tree
Hide file tree
Showing 8 changed files with 255 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.celeborn.common.network.protocol.*;
import org.apache.celeborn.plugin.flink.buffer.FlinkNettyManagedBuffer;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;

public class MessageDecoderExt {
public static Message decode(Message.Type type, ByteBuf in, boolean decodeBody) {
Expand Down Expand Up @@ -74,6 +75,11 @@ public static Message decode(Message.Type type, ByteBuf in, boolean decodeBody)
streamId = in.readLong();
return new ReadData(streamId);

case SUBPARTITION_READ_DATA:
streamId = in.readLong();
int subPartitionId = in.readInt();
return new SubPartitionReadData(streamId, subPartitionId);

case BACKLOG_ANNOUNCEMENT:
streamId = in.readLong();
int backlog = in.readInt();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.celeborn.common.network.server.BaseMessageHandler;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.plugin.flink.protocol.ReadData;
import org.apache.celeborn.plugin.flink.protocol.SubPartitionReadData;

public class ReadClientHandler extends BaseMessageHandler {
private static Logger logger = LoggerFactory.getLogger(ReadClientHandler.class);
Expand All @@ -65,6 +66,8 @@ private void processMessageInternal(long streamId, RequestMessage msg) {
} else {
if (msg != null && msg instanceof ReadData) {
((ReadData) msg).getFlinkBuffer().release();
} else if (msg != null && msg instanceof SubPartitionReadData) {
((SubPartitionReadData) msg).getFlinkBuffer().release();
}

logger.warn("Unexpected streamId received: {}", streamId);
Expand All @@ -83,6 +86,10 @@ public void receive(TransportClient client, RequestMessage msg) {
ReadData readData = (ReadData) msg;
processMessageInternal(readData.getStreamId(), readData);
break;
case SUBPARTITION_READ_DATA:
SubPartitionReadData subPartitionReadData = (SubPartitionReadData) msg;
processMessageInternal(subPartitionReadData.getStreamId(), subPartitionReadData);
break;
case BACKLOG_ANNOUNCEMENT:
BacklogAnnouncement backlogAnnouncement = (BacklogAnnouncement) msg;
processMessageInternal(backlogAnnouncement.getStreamId(), backlogAnnouncement);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.plugin.flink.protocol;

import java.util.Objects;

import org.apache.flink.shaded.netty4.io.netty.buffer.ByteBuf;

import org.apache.celeborn.common.network.protocol.ReadData;
import org.apache.celeborn.common.network.protocol.RequestMessage;

/**
* Comparing {@link ReadData}, this class has an additional field of subpartitionId. This class is
* added to keep the backward compatibility.
*/
public class SubPartitionReadData extends RequestMessage {
private final long streamId;
private final int subPartitionId;
private ByteBuf flinkBuffer;

@Override
public boolean needCopyOut() {
return true;
}

public SubPartitionReadData(long streamId, int subPartitionId) {
this.subPartitionId = subPartitionId;
this.streamId = streamId;
}

@Override
public int encodedLength() {
return 8 + 4;
}

// This method will not be called because ReadData won't be created at flink client.
@Override
public void encode(io.netty.buffer.ByteBuf buf) {
buf.writeLong(streamId);
buf.writeInt(subPartitionId);
}

public long getStreamId() {
return streamId;
}

public int getSubPartitionId() {
return subPartitionId;
}

@Override
public Type type() {
return Type.SUBPARTITION_READ_DATA;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SubPartitionReadData readData = (SubPartitionReadData) o;
return streamId == readData.streamId
&& subPartitionId == readData.subPartitionId
&& Objects.equals(flinkBuffer, readData.flinkBuffer);
}

@Override
public int hashCode() {
return Objects.hash(streamId, subPartitionId, flinkBuffer);
}

@Override
public String toString() {
return "SubpartitionReadData{"
+ "streamId="
+ streamId
+ ", subPartitionId="
+ subPartitionId
+ ", flinkBuffer="
+ flinkBuffer
+ '}';
}

public ByteBuf getFlinkBuffer() {
return flinkBuffer;
}

public void setFlinkBuffer(ByteBuf flinkBuffer) {
this.flinkBuffer = flinkBuffer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ public enum Type implements Encodable {
BACKLOG_ANNOUNCEMENT(19),
TRANSPORTABLE_ERROR(20),
BUFFER_STREAM_END(21),
HEARTBEAT(22);
HEARTBEAT(22),
SEGMENT_START(23),
NOTIFY_REQUIRED_SEGMENT(24),
SUBPARTITION_READ_DATA(25);

private final byte id;

Type(int id) {
Expand Down Expand Up @@ -164,6 +168,12 @@ public static Type decode(ByteBuf buf) {
return BUFFER_STREAM_END;
case 22:
return HEARTBEAT;
case 23:
return SEGMENT_START;
case 24:
return NOTIFY_REQUIRED_SEGMENT;
case 25:
return SUBPARTITION_READ_DATA;
case -1:
throw new IllegalArgumentException("User type messages cannot be decoded.");
default:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.celeborn.common.network.protocol;

import java.util.Objects;

import io.netty.buffer.ByteBuf;

import org.apache.celeborn.common.network.buffer.NettyManagedBuffer;

/**
* Comparing {@link ReadData}, this class has an additional field of subpartitionId. This class is
* added to keep the backward compatibility.
*/
public class SubPartitionReadData extends RequestMessage {
private long streamId;

private int subPartitionId;

public SubPartitionReadData(long streamId, int subPartitionId, ByteBuf buf) {
super(new NettyManagedBuffer(buf));
this.streamId = streamId;
this.subPartitionId = subPartitionId;
}

@Override
public int encodedLength() {
return 8 + 4;
}

@Override
public void encode(ByteBuf buf) {
buf.writeLong(streamId);
buf.writeInt(subPartitionId);
}

public long getStreamId() {
return streamId;
}

public int getSubPartitionId() {
return subPartitionId;
}

@Override
public Type type() {
return Type.SUBPARTITION_READ_DATA;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
SubPartitionReadData readData = (SubPartitionReadData) o;
return streamId == readData.streamId
&& subPartitionId == readData.subPartitionId
&& super.equals(o);
}

@Override
public int hashCode() {
return Objects.hash(streamId, subPartitionId, super.hashCode());
}

@Override
public String toString() {
return "SubpartitionReadData{"
+ "streamId="
+ streamId
+ ", subPartitionId="
+ subPartitionId
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ public <T extends GeneratedMessageV3> T getParsedPayload() throws InvalidProtoco
return (T) PbOpenStreamList.parseFrom(payload);
case BATCH_OPEN_STREAM_RESPONSE_VALUE:
return (T) PbOpenStreamListResponse.parseFrom(payload);
case SEGMENT_START_VALUE:
return (T) PbSegmentStart.parseFrom(payload);
case NOTIFY_REQUIRED_SEGMENT_VALUE:
return (T) PbNotifyRequiredSegment.parseFrom(payload);
default:
logger.error("Unexpected type {}", type);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ public enum StatusCode {
DESTROY_SLOTS_MOCK_FAILURE(48),
COMMIT_FILES_MOCK_FAILURE(49),
PUSH_DATA_FAIL_NON_CRITICAL_CAUSE_REPLICA(50),
OPEN_STREAM_FAILED(51);
OPEN_STREAM_FAILED(51),
SEGMENT_START_FAIL_REPLICA(52),
SEGMENT_START_FAIL_PRIMARY(53);

private final byte value;

Expand Down
28 changes: 28 additions & 0 deletions common/src/main/proto/TransportMessages.proto
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ enum MessageType {
REPORT_WORKER_DECOMMISSION = 82;
REPORT_BARRIER_STAGE_ATTEMPT_FAILURE = 83;
REPORT_BARRIER_STAGE_ATTEMPT_FAILURE_RESPONSE = 84;
SEGMENT_START = 85;
NOTIFY_REQUIRED_SEGMENT = 86;
}

enum StreamType {
Expand Down Expand Up @@ -258,6 +260,7 @@ message PbRegisterMapPartitionTask {
int32 mapId = 3;
int32 attemptId = 4;
int32 partitionId = 5;
bool isSegmentGranularityVisible = 6;
}

message PbRegisterShuffleResponse {
Expand Down Expand Up @@ -353,6 +356,7 @@ message PbMapperEndResponse {

message PbGetReducerFileGroup {
int32 shuffleId = 1;
bool isSegmentGranularityVisible = 2;
}

message PbGetReducerFileGroupResponse {
Expand Down Expand Up @@ -471,6 +475,7 @@ message PbReserveSlots {
bool partitionSplitEnabled = 11;
int32 availableStorageTypes = 12;
PbPackedPartitionLocationsPair partitionLocationsPair = 13;
bool isSegmentGranularityVisible = 14;
}

message PbReserveSlotsResponse {
Expand Down Expand Up @@ -576,6 +581,13 @@ message PbFileInfo {
int32 numSubpartitions = 6;
int64 bytesFlushed = 7;
bool partitionSplitEnabled = 8;
bool isSegmentGranularityVisible = 9;
map<int32, int32> partitionWritingSegment = 10;
repeated PbSegmentIndex segmentIndex = 11;
}

message PbSegmentIndex {
map<int32, int32> firstBufferIndexToSegment = 1;
}

message PbMapFileMeta {
Expand Down Expand Up @@ -653,6 +665,7 @@ message PbOpenStream {
int32 endIndex = 4;
int32 initialCredit = 5;
bool readLocalShuffle = 6;
bool requireSubpartitionId = 7;
}

message PbStreamHandler {
Expand Down Expand Up @@ -760,6 +773,21 @@ message PbAuthenticationInitiationRequest {
repeated PbSaslMechanism saslMechanisms = 3;
}

message PbSegmentStart {
PbPartitionLocation.Mode mode = 1;
string shuffleKey = 2;
string partitionUniqueId = 3;
int32 attemptId = 4;
int32 subPartitionId = 5;
int32 segmentId = 6;
}

message PbNotifyRequiredSegment {
int64 streamId = 1;
int32 requiredSegmentId = 2;
int32 subPartitionId = 3;
}

message PbAuthenticationInitiationResponse {
string version = 1;
bool authEnabled = 2;
Expand Down

0 comments on commit 5d61458

Please sign in to comment.