Skip to content

Commit

Permalink
Introduce ServiceVersion framework to support backward compatible and…
Browse files Browse the repository at this point in the history
… reduce BlcokIdLayout message
  • Loading branch information
maobaolong committed Oct 22, 2024
1 parent 2b70eb4 commit dd6f6b7
Show file tree
Hide file tree
Showing 17 changed files with 268 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,8 @@ public void registerShuffle(
dataDistributionType,
maxConcurrencyPerPartitionToWrite,
stageAttemptNumber,
mergeContext);
mergeContext,
blockIdLayout);
RssRegisterShuffleResponse response =
getShuffleServerClient(shuffleServerInfo).registerShuffle(request);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ public class ShuffleServerInfo implements Serializable {

private int nettyPort = -1;

private int serviceVersion = 0;

@VisibleForTesting
public ShuffleServerInfo(String host, int port) {
this.id = host + "-" + port;
Expand All @@ -57,10 +59,16 @@ public ShuffleServerInfo(String id, String host, int port) {
}

public ShuffleServerInfo(String id, String host, int grpcPort, int nettyPort) {
this(id, host, grpcPort, nettyPort, 0);
}

public ShuffleServerInfo(
String id, String host, int grpcPort, int nettyPort, int serviceVersion) {
this.id = id;
this.host = host;
this.grpcPort = grpcPort;
this.nettyPort = nettyPort;
this.serviceVersion = serviceVersion;
}

public String getId() {
Expand All @@ -79,6 +87,10 @@ public int getNettyPort() {
return nettyPort;
}

public int getServiceVersion() {
return serviceVersion;
}

@Override
public int hashCode() {
// By default id = host + "-" + grpc port, if netty port is greater than 0,
Expand Down Expand Up @@ -121,7 +133,8 @@ private static ShuffleServerInfo convertFromShuffleServerId(
shuffleServerId.getId(),
shuffleServerId.getIp(),
shuffleServerId.getPort(),
shuffleServerId.getNettyPort());
shuffleServerId.getNettyPort(),
shuffleServerId.getServiceVersion());
return shuffleServerInfo;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ public static ShuffleServerInfo decodeShuffleServerInfo(ByteBuf byteBuf) {
String host = ByteBufUtils.readLengthAndString(byteBuf);
int grpcPort = byteBuf.readInt();
int nettyPort = byteBuf.readInt();
return new ShuffleServerInfo(id, host, grpcPort, nettyPort);
// this decodeShuffleServerInfo method is deprecated,
// clients do not need to encode service version
return new ShuffleServerInfo(id, host, grpcPort, nettyPort, 0);
}

public static ShuffleBlockInfo decodeShuffleBlockInfo(ByteBuf byteBuf) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.uniffle.common.rpc;

import java.util.Arrays;
import java.util.Map;
import java.util.stream.Collectors;

public class ServiceVersion {
public static final ServiceVersion NEWEST_VERSION =
new ServiceVersion(Feature.getNewest().getVersion());
static final Map<Integer, Feature> VALUE_MAP =
Arrays.stream(Feature.values()).collect(Collectors.toMap(Feature::getVersion, s -> s));
private final int version;

public ServiceVersion(int version) {
this.version = version;
}

public Feature getCurrentFeature() {
return VALUE_MAP.get(version);
}

public boolean supportFeature(Feature registerBlockIdLayout) {
return version >= registerBlockIdLayout.getVersion();
}

public int getVersion() {
return version;
}

public enum Feature {
// Treat the old version as init version
INIT_VERSION(0),
// Register block id layout to server to avoid sending block id layout for each getShuffleResult
// request
REGISTER_BLOCK_ID_LAYOUT(1),
;

private final int version;

Feature(int version) {
this.version = version;
}

public int getVersion() {
return version;
}

public static Feature getNewest() {
Feature[] enumConstants = Feature.class.getEnumConstants();
return enumConstants[enumConstants.length - 1];
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.proto.RssProtos;

/**
* This represents the actual bit layout of {@link BlockId}s.
Expand Down Expand Up @@ -194,6 +195,14 @@ public BlockId asBlockId(int sequenceNo, int partitionId, long taskAttemptId) {
(int) taskAttemptId);
}

public RssProtos.BlockIdLayout toProto() {
return RssProtos.BlockIdLayout.newBuilder()
.setSequenceNoBits(sequenceNoBits)
.setPartitionIdBits(partitionIdBits)
.setTaskAttemptIdBits(taskAttemptIdBits)
.build();
}

public static BlockIdLayout from(RssConf rssConf) {
int sequenceBits = rssConf.get(RssClientConf.BLOCKID_SEQUENCE_NO_BITS);
int partitionBits = rssConf.get(RssClientConf.BLOCKID_PARTITION_ID_BITS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,8 @@ private ServerNode toServerNode(ShuffleServerHeartBeatRequest request) {
request.getStartTimeMs(),
request.getVersion(),
request.getGitCommitId(),
request.getApplicationInfoList());
request.getApplicationInfoList(),
request.getServiceVersion());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

Expand All @@ -33,6 +34,7 @@

public class ServerNode implements Comparable<ServerNode> {

private final int serviceVersion;
private String id;
private String ip;
private int grpcPort;
Expand All @@ -56,7 +58,7 @@ public ServerNode(String id) {
this(id, "", 0, 0, 0, 0, 0, Sets.newHashSet(), ServerStatus.EXCLUDED);
}

// Only for test
@VisibleForTesting
public ServerNode(
String id,
String ip,
Expand Down Expand Up @@ -129,6 +131,7 @@ public ServerNode(
-1);
}

@VisibleForTesting
public ServerNode(
String id,
String ip,
Expand Down Expand Up @@ -187,7 +190,8 @@ public ServerNode(
startTime,
"",
"",
Collections.EMPTY_LIST);
Collections.EMPTY_LIST,
0);
}

public ServerNode(
Expand All @@ -206,7 +210,8 @@ public ServerNode(
long startTime,
String version,
String gitCommitId,
List<RssProtos.ApplicationInfo> appInfos) {
List<RssProtos.ApplicationInfo> appInfos,
int serviceVersion) {
this.id = id;
this.ip = ip;
this.grpcPort = grpcPort;
Expand All @@ -230,6 +235,7 @@ public ServerNode(
this.gitCommitId = gitCommitId;
this.appIdToInfos = new ConcurrentHashMap<>();
appInfos.forEach(appInfo -> appIdToInfos.put(appInfo.getAppId(), appInfo));
this.serviceVersion = serviceVersion;
}

public ShuffleServerId convertToGrpcProto() {
Expand All @@ -239,6 +245,7 @@ public ShuffleServerId convertToGrpcProto() {
.setPort(grpcPort)
.setNettyPort(nettyPort)
.setJettyPort(jettyPort)
.setServiceVersion(serviceVersion)
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,17 @@ private ShuffleServerClient createShuffleServerClient(
String clientType, ShuffleServerInfo shuffleServerInfo, RssConf rssConf) {
if (clientType.equalsIgnoreCase(ClientType.GRPC.name())) {
return new ShuffleServerGrpcClient(
rssConf, shuffleServerInfo.getHost(), shuffleServerInfo.getGrpcPort());
rssConf,
shuffleServerInfo.getHost(),
shuffleServerInfo.getGrpcPort(),
shuffleServerInfo.getServiceVersion());
} else if (clientType.equalsIgnoreCase(ClientType.GRPC_NETTY.name())) {
return new ShuffleServerGrpcNettyClient(
rssConf,
shuffleServerInfo.getHost(),
shuffleServerInfo.getGrpcPort(),
shuffleServerInfo.getNettyPort());
shuffleServerInfo.getNettyPort(),
shuffleServerInfo.getServiceVersion());
} else {
throw new UnsupportedOperationException("Unsupported client type " + clientType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.apache.uniffle.common.ServerStatus;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.ServiceVersion;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.storage.StorageInfo;
import org.apache.uniffle.common.storage.StorageInfoUtils;
Expand Down Expand Up @@ -151,6 +152,7 @@ public ShuffleServerHeartBeatResponse doSendHeartBeat(
.setVersion(Constants.VERSION)
.setGitCommitId(Constants.REVISION_SHORT)
.addAllApplicationInfo(appInfos)
.setServiceVersion(ServiceVersion.NEWEST_VERSION.getVersion())
.build();

RssProtos.StatusCode status;
Expand Down Expand Up @@ -424,7 +426,11 @@ public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers(
.map(
ss ->
new ShuffleServerInfo(
ss.getId(), ss.getIp(), ss.getPort(), ss.getNettyPort()))
ss.getId(),
ss.getIp(),
ss.getPort(),
ss.getNettyPort(),
ss.getServiceVersion()))
.collect(Collectors.toList());
for (int i = startPartition; i <= endPartition; i++) {
partitionToServers.put(i, shuffleServerInfos);
Expand All @@ -449,7 +455,12 @@ public Map<ShuffleServerInfo, List<PartitionRange>> getServerToPartitionRanges(
new PartitionRange(assign.getStartPartition(), assign.getEndPartition());
for (ShuffleServerId ssi : shuffleServerIds) {
ShuffleServerInfo shuffleServerInfo =
new ShuffleServerInfo(ssi.getId(), ssi.getIp(), ssi.getPort(), ssi.getNettyPort());
new ShuffleServerInfo(
ssi.getId(),
ssi.getIp(),
ssi.getPort(),
ssi.getNettyPort(),
ssi.getServiceVersion());
if (!serverToPartitionRanges.containsKey(shuffleServerInfo)) {
serverToPartitionRanges.put(shuffleServerInfo, Lists.newArrayList());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,35 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.common.rpc.ServiceVersion;
import org.apache.uniffle.common.util.GrpcNettyUtils;

public abstract class GrpcClient {

private static final Logger logger = LoggerFactory.getLogger(GrpcClient.class);
protected String host;
protected int port;
protected ServiceVersion serviceVersion;
protected boolean usePlaintext;
protected int maxRetryAttempts;
protected ManagedChannel channel;

protected GrpcClient(String host, int port, int maxRetryAttempts, boolean usePlaintext) {
this(host, port, maxRetryAttempts, usePlaintext, 0, 0, 0);
this(host, port, 0, maxRetryAttempts, usePlaintext, 0, 0, 0);
}

protected GrpcClient(
String host,
int port,
int serviceVersion,
int maxRetryAttempts,
boolean usePlaintext,
int pageSize,
int maxOrder,
int smallCacheSize) {
this.host = host;
this.port = port;
this.serviceVersion = new ServiceVersion(serviceVersion);
this.maxRetryAttempts = maxRetryAttempts;
this.usePlaintext = usePlaintext;

Expand All @@ -75,6 +79,10 @@ protected GrpcClient(ManagedChannel channel) {
this.channel = channel;
}

public ServiceVersion getServiceVersion() {
return serviceVersion;
}

public void close() {
try {
channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
Expand Down
Loading

0 comments on commit dd6f6b7

Please sign in to comment.