diff --git a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java index c81d3c7255..6575b87b87 100644 --- a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java +++ b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleWriteClientImpl.java @@ -583,7 +583,8 @@ public void registerShuffle( dataDistributionType, maxConcurrencyPerPartitionToWrite, stageAttemptNumber, - mergeContext); + mergeContext, + blockIdLayout); RssRegisterShuffleResponse response = getShuffleServerClient(shuffleServerInfo).registerShuffle(request); diff --git a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java index 6ebea551ab..35116aed37 100644 --- a/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java +++ b/common/src/main/java/org/apache/uniffle/common/ShuffleServerInfo.java @@ -35,6 +35,8 @@ public class ShuffleServerInfo implements Serializable { private int nettyPort = -1; + private int serviceVersion = 0; + @VisibleForTesting public ShuffleServerInfo(String host, int port) { this.id = host + "-" + port; @@ -57,10 +59,16 @@ public ShuffleServerInfo(String id, String host, int port) { } public ShuffleServerInfo(String id, String host, int grpcPort, int nettyPort) { + this(id, host, grpcPort, nettyPort, 0); + } + + public ShuffleServerInfo( + String id, String host, int grpcPort, int nettyPort, int serviceVersion) { this.id = id; this.host = host; this.grpcPort = grpcPort; this.nettyPort = nettyPort; + this.serviceVersion = serviceVersion; } public String getId() { @@ -79,6 +87,10 @@ public int getNettyPort() { return nettyPort; } + public int getServiceVersion() { + return serviceVersion; + } + @Override public int hashCode() { // By default id = host + "-" + grpc port, if netty port is greater than 0, @@ -121,7 +133,8 @@ private static 
ShuffleServerInfo convertFromShuffleServerId( shuffleServerId.getId(), shuffleServerId.getIp(), shuffleServerId.getPort(), - shuffleServerId.getNettyPort()); + shuffleServerId.getNettyPort(), + shuffleServerId.getServiceVersion()); return shuffleServerInfo; } diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java index fc70048808..13f564aa57 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Decoders.java @@ -36,7 +36,9 @@ public static ShuffleServerInfo decodeShuffleServerInfo(ByteBuf byteBuf) { String host = ByteBufUtils.readLengthAndString(byteBuf); int grpcPort = byteBuf.readInt(); int nettyPort = byteBuf.readInt(); - return new ShuffleServerInfo(id, host, grpcPort, nettyPort); + // this decodeShuffleServerInfo method is deprecated, + // clients do not need to encode service version + return new ShuffleServerInfo(id, host, grpcPort, nettyPort, 0); } public static ShuffleBlockInfo decodeShuffleBlockInfo(ByteBuf byteBuf) { diff --git a/common/src/main/java/org/apache/uniffle/common/rpc/ServiceVersion.java b/common/src/main/java/org/apache/uniffle/common/rpc/ServiceVersion.java new file mode 100644 index 0000000000..33f9c69c70 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/rpc/ServiceVersion.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/**
 * Numeric service (protocol) version of an RPC endpoint, together with the catalogue of features
 * that were introduced at each version.
 *
 * <p>A version supports every {@link Feature} whose own version number is less than or equal to
 * it, so callers can gate optional protocol behaviour with {@link #supportFeature(Feature)}.
 */
public class ServiceVersion {
  /** The version that carries the highest feature number known to this build. */
  public static final ServiceVersion NEWEST_VERSION =
      new ServiceVersion(Feature.getNewest().getVersion());

  /** Read-only lookup from a version number to the feature introduced at exactly that version. */
  static final Map<Integer, Feature> VALUE_MAP = indexFeaturesByVersion();

  private final int version;

  /**
   * Creates a service version.
   *
   * @param version the raw version number; {@code 0} is the initial version
   */
  public ServiceVersion(int version) {
    this.version = version;
  }

  /**
   * Builds the version-to-feature index.
   *
   * <p>{@code Collectors.toMap} rejects duplicate keys, so two features that declare the same
   * version number fail fast during class initialisation instead of shadowing each other.
   */
  private static Map<Integer, Feature> indexFeaturesByVersion() {
    Map<Integer, Feature> byVersion =
        Arrays.stream(Feature.values())
            .collect(Collectors.toMap(Feature::getVersion, feature -> feature));
    // Shared static state: hand out a read-only view so no caller can alter the catalogue.
    return java.util.Collections.unmodifiableMap(byVersion);
  }

  /**
   * Returns the feature introduced at exactly this version.
   *
   * @return the matching feature, or {@code null} when no {@link Feature} declares this exact
   *     version number
   */
  public Feature getCurrentFeature() {
    return VALUE_MAP.get(version);
  }

  /**
   * Tells whether this version is new enough to include the given feature.
   *
   * @param feature the feature to test for
   * @return {@code true} if this version is greater than or equal to the feature's version
   */
  public boolean supportFeature(Feature feature) {
    return version >= feature.getVersion();
  }

  /**
   * Returns the raw version number.
   *
   * @return the version number passed to the constructor
   */
  public int getVersion() {
    return version;
  }

  /** Features of the service protocol, each tagged with the version that introduced it. */
  public enum Feature {
    // Treat the old version as init version
    INIT_VERSION(0),
    // Register block id layout to server to avoid sending block id layout for each getShuffleResult
    // request
    REGISTER_BLOCK_ID_LAYOUT(1),
    ;

    private final int version;

    Feature(int version) {
      this.version = version;
    }

    /**
     * Returns the version that introduced this feature.
     *
     * @return the feature's version number
     */
    public int getVersion() {
      return version;
    }

    /**
     * Returns the feature with the highest version number.
     *
     * <p>The result is chosen by comparing version numbers rather than by declaration position,
     * so it stays correct even if a constant is later added out of order.
     *
     * @return the newest feature
     */
    public static Feature getNewest() {
      Feature newest = INIT_VERSION;
      for (Feature candidate : values()) {
        if (candidate.version > newest.version) {
          newest = candidate;
        }
      }
      return newest;
    }
  }
}
b/common/src/main/java/org/apache/uniffle/common/util/BlockIdLayout.java @@ -21,6 +21,7 @@ import org.apache.uniffle.common.config.RssClientConf; import org.apache.uniffle.common.config.RssConf; +import org.apache.uniffle.proto.RssProtos; /** * This represents the actual bit layout of {@link BlockId}s. @@ -194,6 +195,14 @@ public BlockId asBlockId(int sequenceNo, int partitionId, long taskAttemptId) { (int) taskAttemptId); } + public RssProtos.BlockIdLayout toProto() { + return RssProtos.BlockIdLayout.newBuilder() + .setSequenceNoBits(sequenceNoBits) + .setPartitionIdBits(partitionIdBits) + .setTaskAttemptIdBits(taskAttemptIdBits) + .build(); + } + public static BlockIdLayout from(RssConf rssConf) { int sequenceBits = rssConf.get(RssClientConf.BLOCKID_SEQUENCE_NO_BITS); int partitionBits = rssConf.get(RssClientConf.BLOCKID_PARTITION_ID_BITS); diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java index e494300013..58740f7483 100644 --- a/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java +++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/CoordinatorGrpcService.java @@ -538,7 +538,8 @@ private ServerNode toServerNode(ShuffleServerHeartBeatRequest request) { request.getStartTimeMs(), request.getVersion(), request.getGitCommitId(), - request.getApplicationInfoList()); + request.getApplicationInfoList(), + request.getServiceVersion()); } /** diff --git a/coordinator/src/main/java/org/apache/uniffle/coordinator/ServerNode.java b/coordinator/src/main/java/org/apache/uniffle/coordinator/ServerNode.java index 356c4bfe9d..f3306fd0e9 100644 --- a/coordinator/src/main/java/org/apache/uniffle/coordinator/ServerNode.java +++ b/coordinator/src/main/java/org/apache/uniffle/coordinator/ServerNode.java @@ -23,6 +23,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import 
com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -33,6 +34,7 @@ public class ServerNode implements Comparable { + private final int serviceVersion; private String id; private String ip; private int grpcPort; @@ -56,7 +58,7 @@ public ServerNode(String id) { this(id, "", 0, 0, 0, 0, 0, Sets.newHashSet(), ServerStatus.EXCLUDED); } - // Only for test + @VisibleForTesting public ServerNode( String id, String ip, @@ -129,6 +131,7 @@ public ServerNode( -1); } + @VisibleForTesting public ServerNode( String id, String ip, @@ -187,7 +190,8 @@ public ServerNode( startTime, "", "", - Collections.EMPTY_LIST); + Collections.EMPTY_LIST, + 0); } public ServerNode( @@ -206,7 +210,8 @@ public ServerNode( long startTime, String version, String gitCommitId, - List appInfos) { + List appInfos, + int serviceVersion) { this.id = id; this.ip = ip; this.grpcPort = grpcPort; @@ -230,6 +235,7 @@ public ServerNode( this.gitCommitId = gitCommitId; this.appIdToInfos = new ConcurrentHashMap<>(); appInfos.forEach(appInfo -> appIdToInfos.put(appInfo.getAppId(), appInfo)); + this.serviceVersion = serviceVersion; } public ShuffleServerId convertToGrpcProto() { @@ -239,6 +245,7 @@ public ShuffleServerId convertToGrpcProto() { .setPort(grpcPort) .setNettyPort(nettyPort) .setJettyPort(jettyPort) + .setServiceVersion(serviceVersion) .build(); } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleServerClientFactory.java b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleServerClientFactory.java index 8b4f5f91a2..90168abad1 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleServerClientFactory.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/factory/ShuffleServerClientFactory.java @@ -47,13 +47,17 @@ private ShuffleServerClient createShuffleServerClient( String clientType, ShuffleServerInfo shuffleServerInfo, RssConf 
rssConf) { if (clientType.equalsIgnoreCase(ClientType.GRPC.name())) { return new ShuffleServerGrpcClient( - rssConf, shuffleServerInfo.getHost(), shuffleServerInfo.getGrpcPort()); + rssConf, + shuffleServerInfo.getHost(), + shuffleServerInfo.getGrpcPort(), + shuffleServerInfo.getServiceVersion()); } else if (clientType.equalsIgnoreCase(ClientType.GRPC_NETTY.name())) { return new ShuffleServerGrpcNettyClient( rssConf, shuffleServerInfo.getHost(), shuffleServerInfo.getGrpcPort(), - shuffleServerInfo.getNettyPort()); + shuffleServerInfo.getNettyPort(), + shuffleServerInfo.getServiceVersion()); } else { throw new UnsupportedOperationException("Unsupported client type " + clientType); } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java index fbfe578247..083567bb80 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/CoordinatorGrpcClient.java @@ -52,6 +52,7 @@ import org.apache.uniffle.common.ServerStatus; import org.apache.uniffle.common.ShuffleServerInfo; import org.apache.uniffle.common.exception.RssException; +import org.apache.uniffle.common.rpc.ServiceVersion; import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.storage.StorageInfo; import org.apache.uniffle.common.storage.StorageInfoUtils; @@ -151,6 +152,7 @@ public ShuffleServerHeartBeatResponse doSendHeartBeat( .setVersion(Constants.VERSION) .setGitCommitId(Constants.REVISION_SHORT) .addAllApplicationInfo(appInfos) + .setServiceVersion(ServiceVersion.NEWEST_VERSION.getVersion()) .build(); RssProtos.StatusCode status; @@ -424,7 +426,11 @@ public Map> getPartitionToServers( .map( ss -> new ShuffleServerInfo( - ss.getId(), ss.getIp(), ss.getPort(), ss.getNettyPort())) + ss.getId(), + ss.getIp(), + ss.getPort(), + 
ss.getNettyPort(), + ss.getServiceVersion())) .collect(Collectors.toList()); for (int i = startPartition; i <= endPartition; i++) { partitionToServers.put(i, shuffleServerInfos); @@ -449,7 +455,12 @@ public Map> getServerToPartitionRanges( new PartitionRange(assign.getStartPartition(), assign.getEndPartition()); for (ShuffleServerId ssi : shuffleServerIds) { ShuffleServerInfo shuffleServerInfo = - new ShuffleServerInfo(ssi.getId(), ssi.getIp(), ssi.getPort(), ssi.getNettyPort()); + new ShuffleServerInfo( + ssi.getId(), + ssi.getIp(), + ssi.getPort(), + ssi.getNettyPort(), + ssi.getServiceVersion()); if (!serverToPartitionRanges.containsKey(shuffleServerInfo)) { serverToPartitionRanges.put(shuffleServerInfo, Lists.newArrayList()); } diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/GrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/GrpcClient.java index 2f5b09c789..58bd3f23dc 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/GrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/GrpcClient.java @@ -25,6 +25,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.uniffle.common.rpc.ServiceVersion; import org.apache.uniffle.common.util.GrpcNettyUtils; public abstract class GrpcClient { @@ -32,17 +33,19 @@ public abstract class GrpcClient { private static final Logger logger = LoggerFactory.getLogger(GrpcClient.class); protected String host; protected int port; + protected ServiceVersion serviceVersion; protected boolean usePlaintext; protected int maxRetryAttempts; protected ManagedChannel channel; protected GrpcClient(String host, int port, int maxRetryAttempts, boolean usePlaintext) { - this(host, port, maxRetryAttempts, usePlaintext, 0, 0, 0); + this(host, port, 0, maxRetryAttempts, usePlaintext, 0, 0, 0); } protected GrpcClient( String host, int port, + int serviceVersion, int maxRetryAttempts, boolean usePlaintext, 
int pageSize, @@ -50,6 +53,7 @@ protected GrpcClient( int smallCacheSize) { this.host = host; this.port = port; + this.serviceVersion = new ServiceVersion(serviceVersion); this.maxRetryAttempts = maxRetryAttempts; this.usePlaintext = usePlaintext; @@ -75,6 +79,10 @@ protected GrpcClient(ManagedChannel channel) { this.channel = channel; } + public ServiceVersion getServiceVersion() { + return serviceVersion; + } + public void close() { try { channel.shutdown().awaitTermination(5, TimeUnit.SECONDS); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index 20b6bf98b1..db1cdd3058 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -76,7 +76,9 @@ import org.apache.uniffle.common.exception.RssException; import org.apache.uniffle.common.exception.RssFetchFailedException; import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; +import org.apache.uniffle.common.rpc.ServiceVersion; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.RetryUtils; import org.apache.uniffle.common.util.RssUtils; import org.apache.uniffle.proto.RssProtos; @@ -147,6 +149,7 @@ public ShuffleServerGrpcClient(String host, int port) { this( host, port, + 0, RssClientConf.RPC_MAX_ATTEMPTS.defaultValue(), RssClientConf.RPC_TIMEOUT_MS.defaultValue(), true, @@ -155,10 +158,11 @@ public ShuffleServerGrpcClient(String host, int port) { 0); } - public ShuffleServerGrpcClient(RssConf rssConf, String host, int port) { + public ShuffleServerGrpcClient(RssConf rssConf, String host, int port, int serviceVersion) { this( host, port, + serviceVersion, rssConf == null ? 
RssClientConf.RPC_MAX_ATTEMPTS.defaultValue() : rssConf.getInteger(RssClientConf.RPC_MAX_ATTEMPTS), @@ -174,13 +178,22 @@ public ShuffleServerGrpcClient(RssConf rssConf, String host, int port) { public ShuffleServerGrpcClient( String host, int port, + int serviceVersion, int maxRetryAttempts, long rpcTimeoutMs, boolean usePlaintext, int pageSize, int maxOrder, int smallCacheSize) { - super(host, port, maxRetryAttempts, usePlaintext, pageSize, maxOrder, smallCacheSize); + super( + host, + port, + serviceVersion, + maxRetryAttempts, + usePlaintext, + pageSize, + maxOrder, + smallCacheSize); blockingStub = ShuffleServerGrpc.newBlockingStub(channel); rpcTimeout = rpcTimeoutMs; } @@ -198,7 +211,8 @@ private ShuffleRegisterResponse doRegisterShuffle( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - MergeContext mergeContext) { + MergeContext mergeContext, + BlockIdLayout blockIdLayout) { ShuffleRegisterRequest.Builder reqBuilder = ShuffleRegisterRequest.newBuilder(); reqBuilder .setAppId(appId) @@ -222,6 +236,9 @@ private ShuffleRegisterResponse doRegisterShuffle( } } reqBuilder.setRemoteStorage(rsBuilder.build()); + if (blockIdLayout != null) { + reqBuilder.setBlockIdLayout(blockIdLayout.toProto()); + } return getBlockingStub().registerShuffle(reqBuilder.build()); } @@ -484,7 +501,8 @@ public RssRegisterShuffleResponse registerShuffle(RssRegisterShuffleRequest requ request.getDataDistributionType(), request.getMaxConcurrencyPerPartitionToWrite(), request.getStageAttemptNumber(), - request.getMergeContext()); + request.getMergeContext(), + request.getBlockIdLayout()); RssRegisterShuffleResponse response; RssProtos.StatusCode statusCode = rpcResponse.getStatus(); @@ -806,18 +824,21 @@ private ReportShuffleResultResponse doReportShuffleResult(ReportShuffleResultReq @Override public RssGetShuffleResultResponse getShuffleResult(RssGetShuffleResultRequest request) { - GetShuffleResultRequest rpcRequest = + 
GetShuffleResultRequest.Builder builder = GetShuffleResultRequest.newBuilder() .setAppId(request.getAppId()) .setShuffleId(request.getShuffleId()) - .setPartitionId(request.getPartitionId()) - .setBlockIdLayout( - RssProtos.BlockIdLayout.newBuilder() - .setSequenceNoBits(request.getBlockIdLayout().sequenceNoBits) - .setPartitionIdBits(request.getBlockIdLayout().partitionIdBits) - .setTaskAttemptIdBits(request.getBlockIdLayout().taskAttemptIdBits) - .build()) - .build(); + .setPartitionId(request.getPartitionId()); + if (!getServiceVersion().supportFeature(ServiceVersion.Feature.REGISTER_BLOCK_ID_LAYOUT)) { + // set blockIdLayout for each request for backward compatibility to the old server + builder.setBlockIdLayout( + RssProtos.BlockIdLayout.newBuilder() + .setSequenceNoBits(request.getBlockIdLayout().sequenceNoBits) + .setPartitionIdBits(request.getBlockIdLayout().partitionIdBits) + .setTaskAttemptIdBits(request.getBlockIdLayout().taskAttemptIdBits) + .build()); + } + GetShuffleResultRequest rpcRequest = builder.build(); GetShuffleResultResponse rpcResponse = getBlockingStub().getShuffleResult(rpcRequest); RssProtos.StatusCode statusCode = rpcResponse.getStatus(); @@ -854,18 +875,21 @@ public RssGetShuffleResultResponse getShuffleResult(RssGetShuffleResultRequest r @Override public RssGetShuffleResultResponse getShuffleResultForMultiPart( RssGetShuffleResultForMultiPartRequest request) { - GetShuffleResultForMultiPartRequest rpcRequest = + GetShuffleResultForMultiPartRequest.Builder builder = GetShuffleResultForMultiPartRequest.newBuilder() .setAppId(request.getAppId()) .setShuffleId(request.getShuffleId()) - .addAllPartitions(request.getPartitions()) - .setBlockIdLayout( - RssProtos.BlockIdLayout.newBuilder() - .setSequenceNoBits(request.getBlockIdLayout().sequenceNoBits) - .setPartitionIdBits(request.getBlockIdLayout().partitionIdBits) - .setTaskAttemptIdBits(request.getBlockIdLayout().taskAttemptIdBits) - .build()) - .build(); + 
.addAllPartitions(request.getPartitions()); + if (!getServiceVersion().supportFeature(ServiceVersion.Feature.REGISTER_BLOCK_ID_LAYOUT)) { + // set blockIdLayout for each request for backward compatibility to the old server + builder.setBlockIdLayout( + RssProtos.BlockIdLayout.newBuilder() + .setSequenceNoBits(request.getBlockIdLayout().sequenceNoBits) + .setPartitionIdBits(request.getBlockIdLayout().partitionIdBits) + .setTaskAttemptIdBits(request.getBlockIdLayout().taskAttemptIdBits) + .build()); + } + GetShuffleResultForMultiPartRequest rpcRequest = builder.build(); GetShuffleResultForMultiPartResponse rpcResponse = getBlockingStub().getShuffleResultForMultiPart(rpcRequest); RssProtos.StatusCode statusCode = rpcResponse.getStatus(); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java index c2fde9176f..c1011cdc4a 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java @@ -63,15 +63,17 @@ public class ShuffleServerGrpcNettyClient extends ShuffleServerGrpcClient { @VisibleForTesting public ShuffleServerGrpcNettyClient(String host, int grpcPort, int nettyPort) { - this(new RssConf(), host, grpcPort, nettyPort); + this(new RssConf(), host, grpcPort, nettyPort, 0); } - public ShuffleServerGrpcNettyClient(RssConf rssConf, String host, int grpcPort, int nettyPort) { + public ShuffleServerGrpcNettyClient( + RssConf rssConf, String host, int grpcPort, int nettyPort, int serviceVersion) { this( rssConf == null ? new RssConf() : rssConf, host, grpcPort, nettyPort, + serviceVersion, rssConf == null ? 
RssClientConf.RPC_MAX_ATTEMPTS.defaultValue() : rssConf.getInteger(RssClientConf.RPC_MAX_ATTEMPTS), @@ -94,12 +96,22 @@ public ShuffleServerGrpcNettyClient( String host, int grpcPort, int nettyPort, + int serviceVersion, int maxRetryAttempts, long rpcTimeoutMs, int pageSize, int maxOrder, int smallCacheSize) { - super(host, grpcPort, maxRetryAttempts, rpcTimeoutMs, true, pageSize, maxOrder, smallCacheSize); + super( + host, + grpcPort, + serviceVersion, + maxRetryAttempts, + rpcTimeoutMs, + true, + pageSize, + maxOrder, + smallCacheSize); this.nettyPort = nettyPort; TransportContext transportContext = new TransportContext(new TransportConf(rssConf)); this.clientFactory = new TransportClientFactory(transportContext); diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java index 92ed1e15e9..54f367d3bd 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssRegisterShuffleRequest.java @@ -19,16 +19,19 @@ import java.util.List; +import com.google.common.annotations.VisibleForTesting; import org.apache.commons.lang3.StringUtils; import org.apache.uniffle.common.PartitionRange; import org.apache.uniffle.common.RemoteStorageInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; import org.apache.uniffle.common.config.RssClientConf; +import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.proto.RssProtos.MergeContext; public class RssRegisterShuffleRequest { + private final BlockIdLayout blockIdLayout; private String appId; private int shuffleId; private List partitionRanges; @@ -40,6 +43,7 @@ public class RssRegisterShuffleRequest { private final MergeContext mergeContext; + @VisibleForTesting public RssRegisterShuffleRequest( String appId, int shuffleId, @@ -57,6 
+61,7 @@ public RssRegisterShuffleRequest( dataDistributionType, maxConcurrencyPerPartitionToWrite, 0, + null, null); } @@ -69,7 +74,8 @@ public RssRegisterShuffleRequest( ShuffleDataDistributionType dataDistributionType, int maxConcurrencyPerPartitionToWrite, int stageAttemptNumber, - MergeContext mergeContext) { + MergeContext mergeContext, + BlockIdLayout blockIdLayout) { this.appId = appId; this.shuffleId = shuffleId; this.partitionRanges = partitionRanges; @@ -79,8 +85,10 @@ public RssRegisterShuffleRequest( this.maxConcurrencyPerPartitionToWrite = maxConcurrencyPerPartitionToWrite; this.stageAttemptNumber = stageAttemptNumber; this.mergeContext = mergeContext; + this.blockIdLayout = blockIdLayout; } + @VisibleForTesting public RssRegisterShuffleRequest( String appId, int shuffleId, @@ -97,9 +105,11 @@ public RssRegisterShuffleRequest( dataDistributionType, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), 0, + null, null); } + @VisibleForTesting public RssRegisterShuffleRequest( String appId, int shuffleId, List partitionRanges, String remoteStoragePath) { this( @@ -111,6 +121,7 @@ public RssRegisterShuffleRequest( ShuffleDataDistributionType.NORMAL, RssClientConf.MAX_CONCURRENCY_PER_PARTITION_TO_WRITE.defaultValue(), 0, + null, null); } @@ -149,4 +160,8 @@ public int getStageAttemptNumber() { public MergeContext getMergeContext() { return mergeContext; } + + public BlockIdLayout getBlockIdLayout() { + return blockIdLayout; + } } diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index d92ec40c7a..0a058ca686 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -197,6 +197,7 @@ message ShuffleRegisterRequest { int32 maxConcurrencyPerPartitionToWrite = 7; int32 stageAttemptNumber = 8; MergeContext mergeContext = 9; + BlockIdLayout blockIdLayout = 11; } enum DataDistribution { @@ -302,6 +303,7 @@ message ShuffleServerHeartBeatRequest { optional string gitCommitId = 23; optional int64 
startTimeMs = 24; repeated ApplicationInfo applicationInfo = 25; + optional int32 serviceVersion = 26; } message ShuffleServerHeartBeatResponse { @@ -315,6 +317,7 @@ message ShuffleServerId { int32 port = 3; int32 netty_port = 4; int32 jetty_port = 5; + int32 serviceVersion = 6; } message ShuffleServerResult { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index 994a25c890..7d7e8baf11 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -234,6 +234,14 @@ public void registerShuffle( String remoteStoragePath = req.getRemoteStorage().getPath(); String user = req.getUser(); int stageAttemptNumber = req.getStageAttemptNumber(); + BlockIdLayout blockIdLayout = null; + if (req.hasBlockIdLayout()) { + blockIdLayout = + BlockIdLayout.from( + req.getBlockIdLayout().getSequenceNoBits(), + req.getBlockIdLayout().getPartitionIdBits(), + req.getBlockIdLayout().getTaskAttemptIdBits()); + } auditContext.withAppId(appId).withShuffleId(shuffleId); auditContext.withArgs( "remoteStoragePath=" @@ -241,7 +249,9 @@ public void registerShuffle( + ", user=" + user + ", stageAttemptNumber=" - + stageAttemptNumber); + + stageAttemptNumber + + ", blockIdLayout=" + + blockIdLayout); // If the Stage is registered for the first time, you do not need to consider the Stage retry // and delete the Block data that has been sent. 
if (stageAttemptNumber > 0) { @@ -322,7 +332,8 @@ public void registerShuffle( new RemoteStorageInfo(remoteStoragePath, remoteStorageConf), user, shuffleDataDistributionType, - maxConcurrencyPerPartitionToWrite); + maxConcurrencyPerPartitionToWrite, + blockIdLayout); if (StatusCode.SUCCESS == result && shuffleServer.isRemoteMergeEnable() && req.hasMergeContext()) { @@ -338,7 +349,8 @@ public void registerShuffle( new RemoteStorageInfo(remoteStoragePath, remoteStorageConf), user, shuffleDataDistributionType, - maxConcurrencyPerPartitionToWrite); + maxConcurrencyPerPartitionToWrite, + blockIdLayout); if (result == StatusCode.SUCCESS) { result = shuffleServer @@ -887,11 +899,14 @@ public void getShuffleResult( String appId = request.getAppId(); int shuffleId = request.getShuffleId(); int partitionId = request.getPartitionId(); - BlockIdLayout blockIdLayout = - BlockIdLayout.from( - request.getBlockIdLayout().getSequenceNoBits(), - request.getBlockIdLayout().getPartitionIdBits(), - request.getBlockIdLayout().getTaskAttemptIdBits()); + BlockIdLayout blockIdLayout = null; + if (request.hasBlockIdLayout()) { + blockIdLayout = + BlockIdLayout.from( + request.getBlockIdLayout().getSequenceNoBits(), + request.getBlockIdLayout().getPartitionIdBits(), + request.getBlockIdLayout().getTaskAttemptIdBits()); + } auditContext.withAppId(appId).withShuffleId(shuffleId); auditContext.withArgs("partitionId=" + partitionId + ", blockIdLayout=" + blockIdLayout); @@ -955,11 +970,14 @@ public void getShuffleResultForMultiPart( int shuffleId = request.getShuffleId(); List partitionsList = request.getPartitionsList(); - BlockIdLayout blockIdLayout = - BlockIdLayout.from( - request.getBlockIdLayout().getSequenceNoBits(), - request.getBlockIdLayout().getPartitionIdBits(), - request.getBlockIdLayout().getTaskAttemptIdBits()); + BlockIdLayout blockIdLayout = null; + if (request.hasBlockIdLayout()) { + blockIdLayout = + BlockIdLayout.from( + request.getBlockIdLayout().getSequenceNoBits(), + 
request.getBlockIdLayout().getPartitionIdBits(), + request.getBlockIdLayout().getTaskAttemptIdBits()); + } auditContext.withAppId(appId).withShuffleId(shuffleId); auditContext.withArgs( diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java index 94987d6614..f16718e936 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskInfo.java @@ -31,6 +31,7 @@ import org.apache.uniffle.common.PartitionInfo; import org.apache.uniffle.common.ShuffleDataDistributionType; +import org.apache.uniffle.common.util.BlockIdLayout; import org.apache.uniffle.common.util.JavaUtils; import org.apache.uniffle.common.util.UnitConverter; @@ -75,6 +76,8 @@ public class ShuffleTaskInfo { private final Map latestStageAttemptNumbers; + private BlockIdLayout blockIdLayout; + public ShuffleTaskInfo(String appId) { this.appId = appId; this.currentTimes = System.currentTimeMillis(); @@ -307,4 +310,12 @@ public String toString() { + shuffleDetailInfos + '}'; } + + public BlockIdLayout getBlockIdLayout() { + return blockIdLayout; + } + + public void setBlockIdLayout(BlockIdLayout blockIdLayout) { + this.blockIdLayout = blockIdLayout; + } } diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index fa531ba024..6954878f98 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -279,7 +279,8 @@ public StatusCode registerShuffle( remoteStorageInfo, user, ShuffleDataDistributionType.NORMAL, - -1); + -1, + null); } public StatusCode registerShuffle( @@ -289,7 +290,8 @@ public StatusCode registerShuffle( RemoteStorageInfo remoteStorageInfo, String user, ShuffleDataDistributionType dataDistType, - int 
maxConcurrencyPerPartitionToWrite) { + int maxConcurrencyPerPartitionToWrite, + BlockIdLayout blockIdLayout) { ReentrantReadWriteLock.WriteLock lock = getAppWriteLock(appId); lock.lock(); try { @@ -303,6 +305,7 @@ public StatusCode registerShuffle( getMaxConcurrencyWriting(maxConcurrencyPerPartitionToWrite, conf)) .dataDistributionType(dataDistType) .build()); + taskInfo.setBlockIdLayout(blockIdLayout); partitionsToBlockIds.computeIfAbsent(appId, key -> JavaUtils.newConcurrentMap()); for (PartitionRange partitionRange : partitionRanges) { @@ -629,6 +632,9 @@ public byte[] getFinishedBlockIds( } ShuffleTaskInfo taskInfo = getShuffleTaskInfo(appId); + if (blockIdLayout == null) { + blockIdLayout = taskInfo.getBlockIdLayout(); + } long expectedBlockNumber = 0; Map> bitmapIndexToPartitions = Maps.newHashMap(); for (int partitionId : partitions) {