Skip to content

Commit

Permalink
Add server tls config
Browse files Browse the repository at this point in the history
  • Loading branch information
anandkrshaw committed Jul 23, 2024
1 parent 3545812 commit a7431c6
Showing 1 changed file with 18 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.grpc.ServerBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
Expand Down Expand Up @@ -64,7 +65,7 @@ public class QueryServer extends PinotQueryWorkerGrpc.PinotQueryWorkerImplBase {
// the key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> CLIENT_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();
private static final Map<Integer, SslContext> SERVER_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();

private final int _port;
private final QueryRunner _queryRunner;
Expand Down Expand Up @@ -96,7 +97,7 @@ public void start() {
throw new RuntimeException("Failed to start QueryServer", e);
}
} else {
_server = NettyServerBuilder.forPort(_port).addService(this).sslContext(buildSslContext(_tlsConfig))
_server = NettyServerBuilder.forPort(_port).addService(this).sslContext(buildGRpcSslContext(_tlsConfig))
.maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build();
}
LOGGER.info("Initialized QueryServer on port: {}", _port);
Expand Down Expand Up @@ -212,23 +213,25 @@ public void cancel(Worker.CancelRequest request, StreamObserver<Worker.CancelRes
responseObserver.onCompleted();
}

private SslContext buildSslContext(TlsConfig tlsConfig) {
private SslContext buildGRpcSslContext(TlsConfig tlsConfig)
throws IllegalArgumentException {
LOGGER.info("Building gRPC SSL context");
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
if (tlsConfig.getKeyStorePath() == null) {
throw new IllegalArgumentException("Must provide key store path for secured gRpc server");
}
SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
SSLFactory sslFactory = RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(tlsConfig,
PinotInsecureMode::isPinotInInsecureMode);
SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
sslFactory.getKeyManagerFactory().ifPresent(sslContextBuilder::keyManager);
SSLFactory sslFactory =
RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(
tlsConfig, PinotInsecureMode::isPinotInInsecureMode);
SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(sslFactory.getKeyManagerFactory().get())
.sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider()));
sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
if (tlsConfig.getSslProvider() != null) {
sslContextBuilder =
GrpcSslContexts.configure(sslContextBuilder, SslProvider.valueOf(tlsConfig.getSslProvider()));
} else {
sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder);
if (tlsConfig.isClientAuthEnabled()) {
sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
}
return sslContextBuilder.build();
} catch (SSLException e) {
return GrpcSslContexts.configure(sslContextBuilder).build();
} catch (Exception e) {
throw new RuntimeException("Failed to build gRPC SSL context", e);
}
});
Expand Down

0 comments on commit a7431c6

Please sign in to comment.