Skip to content

Commit

Permalink
Add TLS support to Dispacth client
Browse files Browse the repository at this point in the history
  • Loading branch information
anandkrshaw committed Jul 24, 2024
1 parent 8731ddb commit ddf4338
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,31 @@

import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import nl.altindag.ssl.SSLFactory;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.proto.Mailbox;
import org.apache.pinot.common.proto.PinotMailboxGrpc;
import org.apache.pinot.common.utils.tls.PinotInsecureMode;
import org.apache.pinot.common.utils.tls.RenewableTlsUtils;
import org.apache.pinot.common.utils.tls.TlsUtils;
import org.apache.pinot.query.mailbox.MailboxService;
import org.apache.pinot.spi.env.PinotConfiguration;
import org.apache.pinot.spi.utils.CommonConstants;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.apache.pinot.spi.utils.CommonConstants.Server.CONFIG_OF_GRPCTLS_SERVER_ENABLED;


/**
Expand All @@ -37,17 +54,33 @@
* send by the sender of the sender/receiver pair.
*/
public class GrpcMailboxServer extends PinotMailboxGrpc.PinotMailboxImplBase {
private static final Logger LOGGER = LoggerFactory.getLogger(GrpcMailboxServer.class);
private static final long DEFAULT_SHUTDOWN_TIMEOUT_MS = 10_000L;

private final MailboxService _mailboxService;
private final Server _server;
// the key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> SERVER_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();

public GrpcMailboxServer(MailboxService mailboxService, PinotConfiguration config) {
TlsConfig tlsConfig = null;
_mailboxService = mailboxService;
if (config.getProperty(CONFIG_OF_GRPCTLS_SERVER_ENABLED, false)) {
tlsConfig = TlsUtils.extractTlsConfig(config, CommonConstants.Server.SERVER_GRPCTLS_PREFIX);
}
int port = mailboxService.getPort();
_server = ServerBuilder.forPort(port).addService(this).maxInboundMessageSize(
config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES,
CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)).build();
if (tlsConfig != null) {
_server = NettyServerBuilder.forPort(port).addService(this).sslContext(buildGRpcSslContext(tlsConfig))
.maxInboundMessageSize(
config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES,
CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)).build();
} else {
_server = ServerBuilder.forPort(port).addService(this).maxInboundMessageSize(
config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES,
CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)).build();
}
}

public void start() {
Expand All @@ -66,6 +99,31 @@ public void shutdown() {
}
}

private SslContext buildGRpcSslContext(TlsConfig tlsConfig)
throws IllegalArgumentException {
LOGGER.info("Building gRPC SSL context");
if (tlsConfig.getKeyStorePath() == null) {
throw new IllegalArgumentException("Must provide key store path for secured gRpc server");
}
SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
SSLFactory sslFactory =
RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(
tlsConfig, PinotInsecureMode::isPinotInInsecureMode);
SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(sslFactory.getKeyManagerFactory().get())
.sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider()));
sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
if (tlsConfig.isClientAuthEnabled()) {
sslContextBuilder.clientAuth(ClientAuth.REQUIRE);
}
return GrpcSslContexts.configure(sslContextBuilder).build();
} catch (Exception e) {
throw new RuntimeException("Failed to build gRPC SSL context", e);
}
});
return sslContext;
}

@Override
public StreamObserver<Mailbox.MailboxContent> open(StreamObserver<Mailbox.MailboxStatus> responseObserver) {
return new MailboxContentObserver(_mailboxService, responseObserver);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,27 @@
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder;
import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider;
import io.grpc.stub.StreamObserver;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import javax.net.ssl.SSLException;
import nl.altindag.ssl.SSLFactory;
import org.apache.pinot.common.config.GrpcConfig;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.common.utils.tls.PinotInsecureMode;
import org.apache.pinot.common.utils.tls.RenewableTlsUtils;
import org.apache.pinot.query.routing.QueryServerInstance;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
Expand All @@ -36,12 +52,28 @@
*/
class DispatchClient {
private static final StreamObserver<Worker.CancelResponse> NO_OP_CANCEL_STREAM_OBSERVER = new CancelObserver();
private static final Logger LOGGER = LoggerFactory.getLogger(DispatchClient.class);
// the key is the hashCode of the TlsConfig, the value is the SslContext
// We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the
// hashCode changes and the map is resized, the SslContext of the old hashCode will be lost.
private static final Map<Integer, SslContext> CLIENT_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>();

private final ManagedChannel _channel;
private final PinotQueryWorkerGrpc.PinotQueryWorkerStub _dispatchStub;

public DispatchClient(String host, int port) {
_channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
this(host, port, new GrpcConfig(Collections.emptyMap()));
}
public DispatchClient(String host, int port, GrpcConfig grpcConfig) {
if (grpcConfig.isUsePlainText()) {
_channel =
ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(grpcConfig.getMaxInboundMessageSizeBytes())
.usePlaintext().build();
} else {
_channel =
NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(grpcConfig.getMaxInboundMessageSizeBytes())
.sslContext(buildSslContext(grpcConfig.getTlsConfig())).build();
}
_dispatchStub = PinotQueryWorkerGrpc.newStub(_channel);
}

Expand All @@ -58,4 +90,26 @@ public void cancel(long requestId) {
Worker.CancelRequest cancelRequest = Worker.CancelRequest.newBuilder().setRequestId(requestId).build();
_dispatchStub.cancel(cancelRequest, NO_OP_CANCEL_STREAM_OBSERVER);
}
private SslContext buildSslContext(TlsConfig tlsConfig) {
LOGGER.info("Building gRPC SSL context");
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
SSLFactory sslFactory = RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(tlsConfig,
PinotInsecureMode::isPinotInInsecureMode);
SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();
sslFactory.getKeyManagerFactory().ifPresent(sslContextBuilder::keyManager);
sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager);
if (tlsConfig.getSslProvider() != null) {
sslContextBuilder =
GrpcSslContexts.configure(sslContextBuilder, SslProvider.valueOf(tlsConfig.getSslProvider()));
} else {
sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder);
}
return sslContextBuilder.build();
} catch (SSLException e) {
throw new RuntimeException("Failed to build gRPC SSL context", e);
}
});
return sslContext;
}
}

0 comments on commit ddf4338

Please sign in to comment.