Skip to content

Commit

Permalink
RATIS-2095. Extract common logic of ratis-shell to RaftUtils for reuse (
Browse files Browse the repository at this point in the history
  • Loading branch information
DaveTeng0 authored Jun 14, 2024
1 parent 25a41e3 commit 95ea26c
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 77 deletions.
124 changes: 123 additions & 1 deletion ratis-shell/src/main/java/org/apache/ratis/shell/cli/RaftUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,38 @@
import org.apache.ratis.client.RaftClient;
import org.apache.ratis.client.RaftClientConfigKeys;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.protocol.GroupInfoReply;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftGroup;
import org.apache.ratis.protocol.RaftGroupId;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.protocol.exceptions.RaftException;
import org.apache.ratis.retry.ExponentialBackoffRetry;
import org.apache.ratis.util.TimeDuration;
import org.apache.ratis.util.function.CheckedFunction;

import java.io.IOException;
import java.io.PrintStream;
import java.net.InetSocketAddress;
import java.util.Properties;
import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.List;
import java.util.Optional;
import java.util.Properties;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.UUID;

/**
* Helper class for raft operations.
*/
public final class RaftUtils {

public static final RaftGroupId DEFAULT_RAFT_GROUP_ID = RaftGroupId.randomId();

private RaftUtils() {
// prevent instantiation
}
Expand Down Expand Up @@ -86,4 +104,108 @@ public static RaftClient createClient(RaftGroup raftGroup) {
.setRetryPolicy(retryPolicy)
.build();
}

/**
* Apply the given function to the given parameter a list.
*
* @param list the input parameter list
* @param function the function to be applied
* @param <PARAMETER> parameter type
* @param <RETURN> return value type
* @param <EXCEPTION> the exception type thrown by the given function.
* @return the first non-null value returned by the given function applied to the given list.
*/
private static <PARAMETER, RETURN, EXCEPTION extends Throwable> RETURN applyFunctionReturnFirstNonNull(
Collection<PARAMETER> list, CheckedFunction<PARAMETER, RETURN, EXCEPTION> function) {
for (PARAMETER parameter : list) {
try {
RETURN ret = function.apply(parameter);
if (ret != null) {
return ret;
}
} catch (Throwable e) {
e.printStackTrace();
}
}
return null;
}

public static List<RaftPeer> buildRaftPeersFromStr(String peers) {
List<InetSocketAddress> addresses = new ArrayList<>();
String[] peersArray = peers.split(",");
for (String peer : peersArray) {
addresses.add(parseInetSocketAddress(peer));
}

return addresses.stream()
.map(addr -> RaftPeer.newBuilder()
.setId(RaftUtils.getPeerId(addr))
.setAddress(addr)
.build()
).collect(Collectors.toList());
}

public static RaftGroupId buildRaftGroupIdFromStr(String groupId) {
return groupId != null && groupId.isEmpty() ? RaftGroupId.valueOf(UUID.fromString(groupId))
: DEFAULT_RAFT_GROUP_ID;
}

public static RaftGroupId retrieveRemoteGroupId(RaftGroupId raftGroupIdFromConfig,
List<RaftPeer> peers,
RaftClient client, PrintStream printStream) throws IOException {
if (!DEFAULT_RAFT_GROUP_ID .equals(raftGroupIdFromConfig)) {
return raftGroupIdFromConfig;
}

final RaftGroupId remoteGroupId;
final List<RaftGroupId> groupIds = applyFunctionReturnFirstNonNull(peers,
p -> client.getGroupManagementApi((p.getId())).list().getGroupIds());

if (groupIds == null) {
printStream.println("Failed to get group ID from " + peers);
throw new IOException("Failed to get group ID from " + peers);
} else if (groupIds.size() == 1) {
remoteGroupId = groupIds.get(0);
} else {
String message = "Unexpected multiple group IDs " + groupIds
+ ". In such case, the target group ID must be specified.";
printStream.println(message);
throw new IOException(message);
}
return remoteGroupId;
}

public static GroupInfoReply retrieveGroupInfoByGroupId(RaftGroupId remoteGroupId, List<RaftPeer> peers,
RaftClient client, PrintStream printStream)
throws IOException {
GroupInfoReply groupInfoReply = applyFunctionReturnFirstNonNull(peers,
p -> client.getGroupManagementApi((p.getId())).info(remoteGroupId));
processReply(groupInfoReply, printStream::println,
() -> "Failed to get group info for group id " + remoteGroupId.getUuid() + " from " + peers);
return groupInfoReply;
}

public static void processReply(RaftClientReply reply, Consumer<String> printer, Supplier<String> message)
throws IOException {
if (reply == null || !reply.isSuccess()) {
final RaftException e = Optional.ofNullable(reply)
.map(RaftClientReply::getException)
.orElseGet(() -> new RaftException("Reply: " + reply));
printer.accept(message.get());
throw new IOException(e.getMessage(), e);
}
}

public static InetSocketAddress parseInetSocketAddress(String address) {
try {
final String[] hostPortPair = address.split(":");
if (hostPortPair.length < 2) {
throw new IllegalArgumentException("Unexpected address format <HOST:PORT>.");
}
return new InetSocketAddress(hostPortPair[0], Integer.parseInt(hostPortPair[1]));
} catch (Exception e) {
throw new IllegalArgumentException("Failed to parse the server address parameter \"" + address + "\".", e);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
package org.apache.ratis.shell.cli.sh.command;

import org.apache.commons.cli.Option;
import org.apache.ratis.protocol.*;
import org.apache.ratis.protocol.exceptions.RaftException;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftGroup;
import org.apache.ratis.protocol.RaftGroupId;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.protocol.GroupInfoReply;
import org.apache.ratis.shell.cli.RaftUtils;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
Expand All @@ -30,48 +34,30 @@
import org.apache.ratis.proto.RaftProtos.RaftPeerRole;
import org.apache.ratis.proto.RaftProtos.RoleInfoProto;
import org.apache.ratis.util.ProtoUtils;
import org.apache.ratis.util.function.CheckedFunction;

import java.io.IOException;
import java.io.PrintStream;
import java.net.InetSocketAddress;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.apache.ratis.shell.cli.RaftUtils.buildRaftGroupIdFromStr;
import static org.apache.ratis.shell.cli.RaftUtils.buildRaftPeersFromStr;
import static org.apache.ratis.shell.cli.RaftUtils.retrieveGroupInfoByGroupId;
import static org.apache.ratis.shell.cli.RaftUtils.retrieveRemoteGroupId;

/**
* The base class for the ratis shell which need to connect to server.
*/
public abstract class AbstractRatisCommand extends AbstractCommand {
public static final String PEER_OPTION_NAME = "peers";
public static final String GROUPID_OPTION_NAME = "groupid";
public static final RaftGroupId DEFAULT_RAFT_GROUP_ID = RaftGroupId.randomId();

/**
* Execute a given function with input parameter from the members of a list.
*
* @param list the input parameters
* @param function the function to be executed
* @param <T> parameter type
* @param <K> return value type
* @param <E> the exception type thrown by the given function.
* @return the value returned by the given function.
*/
public static <T, K, E extends Throwable> K run(Collection<T> list, CheckedFunction<T, K, E> function) {
for (T t : list) {
try {
K ret = function.apply(t);
if (ret != null) {
return ret;
}
} catch (Throwable e) {
e.printStackTrace();
}
}
return null;
}

private RaftGroup raftGroup;
private GroupInfoReply groupInfoReply;

Expand All @@ -81,46 +67,13 @@ protected AbstractRatisCommand(Context context) {

@Override
public int run(CommandLine cl) throws IOException {
List<InetSocketAddress> addresses = new ArrayList<>();
String peersStr = cl.getOptionValue(PEER_OPTION_NAME);
String[] peersArray = peersStr.split(",");
for (String peer : peersArray) {
addresses.add(parseInetSocketAddress(peer));
}

final RaftGroupId raftGroupIdFromConfig = cl.hasOption(GROUPID_OPTION_NAME)?
RaftGroupId.valueOf(UUID.fromString(cl.getOptionValue(GROUPID_OPTION_NAME)))
: DEFAULT_RAFT_GROUP_ID;

List<RaftPeer> peers = addresses.stream()
.map(addr -> RaftPeer.newBuilder()
.setId(RaftUtils.getPeerId(addr))
.setAddress(addr)
.build()
).collect(Collectors.toList());
List<RaftPeer> peers = buildRaftPeersFromStr(cl.getOptionValue(PEER_OPTION_NAME));
RaftGroupId raftGroupIdFromConfig = buildRaftGroupIdFromStr(cl.getOptionValue(GROUPID_OPTION_NAME));
raftGroup = RaftGroup.valueOf(raftGroupIdFromConfig, peers);
PrintStream printStream = getPrintStream();
try (final RaftClient client = RaftUtils.createClient(raftGroup)) {
final RaftGroupId remoteGroupId;
if (raftGroupIdFromConfig != DEFAULT_RAFT_GROUP_ID) {
remoteGroupId = raftGroupIdFromConfig;
} else {
final List<RaftGroupId> groupIds = run(peers,
p -> client.getGroupManagementApi((p.getId())).list().getGroupIds());

if (groupIds == null) {
println("Failed to get group ID from " + peers);
return -1;
} else if (groupIds.size() == 1) {
remoteGroupId = groupIds.get(0);
} else {
println("There are more than one groups, you should specific one. " + groupIds);
return -2;
}
}

groupInfoReply = run(peers, p -> client.getGroupManagementApi((p.getId())).info(remoteGroupId));
processReply(groupInfoReply,
() -> "Failed to get group info for group id " + remoteGroupId.getUuid() + " from " + peers);
final RaftGroupId remoteGroupId = retrieveRemoteGroupId(raftGroupIdFromConfig, peers, client, printStream);
groupInfoReply = retrieveGroupInfoByGroupId(remoteGroupId, peers, client, printStream);
raftGroup = groupInfoReply.getGroup();
}
return 0;
Expand Down Expand Up @@ -168,14 +121,7 @@ protected RaftPeerProto getLeader(RoleInfoProto roleInfo) {
}

protected void processReply(RaftClientReply reply, Supplier<String> messageSupplier) throws IOException {
if (reply == null || !reply.isSuccess()) {
final RaftException e = Optional.ofNullable(reply)
.map(RaftClientReply::getException)
.orElseGet(() -> new RaftException("Reply: " + reply));
final String message = messageSupplier.get();
printf("%s. Error: %s%n", message, e);
throw new IOException(message, e);
}
RaftUtils.processReply(reply, getPrintStream()::println, messageSupplier);
}

protected List<RaftPeerId> getIds(String[] optionValues, BiConsumer<RaftPeerId, InetSocketAddress> consumer) {
Expand Down

0 comments on commit 95ea26c

Please sign in to comment.