Skip to content

Commit

Permalink
RATIS-2095. Move common logic of ratis-shell to RaftUtils so that Ozo…
Browse files Browse the repository at this point in the history
…ne shell could share and use common logic
  • Loading branch information
DaveTeng0 committed May 21, 2024
1 parent 29bba59 commit ca41aca
Show file tree
Hide file tree
Showing 13 changed files with 310 additions and 62 deletions.
8 changes: 8 additions & 0 deletions ratis-shell/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
</dependency>
<dependency>
<groupId>org.apache.ratis</groupId>
<artifactId>ratis-grpc</artifactId>
</dependency>
<dependency>
<groupId>org.apache.ratis</groupId>
<artifactId>ratis-netty</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
165 changes: 162 additions & 3 deletions ratis-shell/src/main/java/org/apache/ratis/shell/cli/RaftUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,35 @@

import org.apache.ratis.client.RaftClient;
import org.apache.ratis.client.RaftClientConfigKeys;
import org.apache.ratis.conf.Parameters;
import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.protocol.RaftGroup;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.grpc.GrpcConfigKeys;
import org.apache.ratis.grpc.GrpcTlsConfig;
import org.apache.ratis.netty.NettyConfigKeys;
import org.apache.ratis.protocol.*;
import org.apache.ratis.protocol.exceptions.RaftException;
import org.apache.ratis.retry.ExponentialBackoffRetry;
import org.apache.ratis.rpc.RpcType;
import org.apache.ratis.rpc.SupportedRpcType;
import org.apache.ratis.util.TimeDuration;
import org.apache.ratis.util.function.CheckedFunction;

import java.io.IOException;
import java.io.PrintStream;
import java.net.InetSocketAddress;
import java.util.Properties;
import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.function.*;
import java.util.stream.Collectors;

/**
* Helper class for raft operations.
*/
public final class RaftUtils {

public static final RaftGroupId DEFAULT_RAFT_GROUP_ID = RaftGroupId.randomId();


private RaftUtils() {
// prevent instantiation
}
Expand Down Expand Up @@ -65,6 +79,18 @@ public static RaftPeerId getPeerId(String host, int port) {
* @return return a raft client
*/
public static RaftClient createClient(RaftGroup raftGroup) {
return createClient(raftGroup, null, null);
}


/**
* Create a raft client to communicate to ratis server.
* @param raftGroup the raft group
* @param rpcType the rpcType
* @param tlsConfig the tlsConfig
* @return return a raft client
*/
public static RaftClient createClient(RaftGroup raftGroup, RpcType rpcType, GrpcTlsConfig tlsConfig) {
RaftProperties properties = new RaftProperties();
RaftClientConfigKeys.Rpc.setRequestTimeout(properties,
TimeDuration.valueOf(15, TimeUnit.SECONDS));
Expand All @@ -84,6 +110,139 @@ public static RaftClient createClient(RaftGroup raftGroup) {
.setRaftGroup(raftGroup)
.setProperties(properties)
.setRetryPolicy(retryPolicy)
.setParameters(setClientTlsConf(rpcType, tlsConfig))
.build();
}

/**
* Execute a given function with input parameter from the members of a list.
*
* @param list the input parameters
* @param function the function to be executed
* @param <T> parameter type
* @param <K> return value type
* @param <E> the exception type thrown by the given function.
* @return the value returned by the given function.
*/
public static <T, K, E extends Throwable> K runFunction(Collection<T> list, CheckedFunction<T, K, E> function) {
for (T t : list) {
try {
K ret = function.apply(t);
if (ret != null) {
return ret;
}
} catch (Throwable e) {
e.printStackTrace();
}
}
return null;
}


public static List<RaftPeer> buildRaftPeersFromStr(String peers) {
List<InetSocketAddress> addresses = new ArrayList<>();
String[] peersArray = peers.split(",");
for (String peer : peersArray) {
addresses.add(parseInetSocketAddress(peer));
}

return addresses.stream()
.map(addr -> RaftPeer.newBuilder()
.setId(RaftUtils.getPeerId(addr))
.setAddress(addr)
.build()
).collect(Collectors.toList());
}

public static RaftGroupId buildRaftGroupIdFromStr(String groupId) {
return (groupId != null && !groupId.equals("")) ? RaftGroupId.valueOf(UUID.fromString(groupId))
: DEFAULT_RAFT_GROUP_ID;
}

public static RaftGroupId retrieveRemoteGroupId(RaftGroupId raftGroupIdFromConfig,
List<RaftPeer> peers,
RaftClient client, PrintStream printStream) throws IOException {
RaftGroupId remoteGroupId;
if (raftGroupIdFromConfig != DEFAULT_RAFT_GROUP_ID) {
return raftGroupIdFromConfig;
} else {
final List<RaftGroupId> groupIds = runFunction(peers,
p -> client.getGroupManagementApi((p.getId())).list().getGroupIds());

if (groupIds == null) {
printStream.println("Failed to get group ID from " + peers);
throw new IOException("Failed to get group ID from " + peers);
} else if (groupIds.size() == 1) {
remoteGroupId = groupIds.get(0);
} else {
printStream.println("There are more than one groups, you should specific one. " + groupIds);
throw new IOException("There are more than one groups, you should specific one. " + groupIds);
}
}

return remoteGroupId;
}

public static GroupInfoReply retrieveGroupInfoByGroupId(RaftGroupId remoteGroupId, List<RaftPeer> peers, RaftClient client, PrintStream printStream)
throws IOException {
GroupInfoReply groupInfoReply = runFunction(peers, p -> client.getGroupManagementApi((p.getId())).info(remoteGroupId));
processReply(groupInfoReply,
printStream::println, "Failed to get group info for group id " + remoteGroupId.getUuid() + " from " + peers);
return groupInfoReply;
}

public static void processReply(RaftClientReply reply, Consumer<String> printer, String message) throws IOException {
processReplyInternal(reply, () -> printer.accept(message));
}

private static void processReplyInternal(RaftClientReply reply, Runnable printer) throws IOException {
if (reply == null || !reply.isSuccess()) {
final RaftException e = Optional.ofNullable(reply)
.map(RaftClientReply::getException)
.orElseGet(() -> new RaftException("Reply: " + reply));
printer.run();
throw new IOException(e.getMessage(), e);
}
}

public static InetSocketAddress parseInetSocketAddress(String address) {
try {
final String[] hostPortPair = address.split(":");
if (hostPortPair.length < 2) {
throw new IllegalArgumentException("Unexpected address format <HOST:PORT>.");
}
return new InetSocketAddress(hostPortPair[0], Integer.parseInt(hostPortPair[1]));
} catch (Exception e) {
throw new IllegalArgumentException("Failed to parse the server address parameter \"" + address + "\".", e);
}
}

public static Parameters setClientTlsConf(RpcType rpcType,
GrpcTlsConfig tlsConfig) {
// TODO: GRPC TLS only for now, netty/hadoop RPC TLS support later.
if (tlsConfig != null && rpcType == SupportedRpcType.GRPC) {
Parameters parameters = new Parameters();
setAdminTlsConf(parameters, tlsConfig);
setClientTlsConf(parameters, tlsConfig);
return parameters;
}
return null;
}

private static void setAdminTlsConf(Parameters parameters,
GrpcTlsConfig tlsConfig) {
if (tlsConfig != null) {
GrpcConfigKeys.Admin.setTlsConf(parameters, tlsConfig);
}
}

private static void setClientTlsConf(Parameters parameters,
GrpcTlsConfig tlsConfig) {
if (tlsConfig != null) {
GrpcConfigKeys.Client.setTlsConf(parameters, tlsConfig);
NettyConfigKeys.DataStream.Client.setTlsConf(parameters, tlsConfig);
}
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package org.apache.ratis.shell.cli;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;
import java.io.*;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.Optional;
import java.util.function.Supplier;

public class SecurityUtils {
static Logger LOG = LoggerFactory.getLogger(SecurityUtils.class);

public static KeyStore getTrustStore()
throws Exception {
X509Certificate[] certificate = getCertificate("ssl/ca.crt");

// build trustStore
KeyStore trustStore = KeyStore.getInstance(KeyStore.getDefaultType());
trustStore.load(null, null);

for (X509Certificate cert: certificate) {
trustStore.setCertificateEntry(cert.getSerialNumber().toString(), cert);
}
return trustStore;
}

public static X509TrustManager getTrustManager(KeyStore keyStore) throws Exception{
TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(
TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(keyStore);
TrustManager[] trustManagers = trustManagerFactory.getTrustManagers();
if (trustManagers.length != 1 || !(trustManagers[0] instanceof X509TrustManager)) {
throw new IllegalStateException("Unexpected default trust managers:"
+ Arrays.toString(trustManagers));
}
return (X509TrustManager) trustManagers[0];
}

static X509Certificate[] getCertificate(String certPath)
throws CertificateException, IOException {
// Read certificates
X509Certificate[] certificate = new X509Certificate[1];
CertificateFactory fact = CertificateFactory.getInstance("X.509");
try (InputStream is = Files.newInputStream(Paths.get(certPath))) {
certificate[0] = (X509Certificate) fact.generateCertificate(is);
}
return certificate;
}

}
Loading

0 comments on commit ca41aca

Please sign in to comment.