diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/mailbox/channel/GrpcMailboxServer.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/mailbox/channel/GrpcMailboxServer.java index aade639f9846..a3eeb1e88981 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/mailbox/channel/GrpcMailboxServer.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/mailbox/channel/GrpcMailboxServer.java @@ -20,14 +20,31 @@ import io.grpc.Server; import io.grpc.ServerBuilder; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.ClientAuth; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; import io.grpc.stub.StreamObserver; import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; +import nl.altindag.ssl.SSLFactory; +import org.apache.pinot.common.config.TlsConfig; import org.apache.pinot.common.proto.Mailbox; import org.apache.pinot.common.proto.PinotMailboxGrpc; +import org.apache.pinot.common.utils.tls.PinotInsecureMode; +import org.apache.pinot.common.utils.tls.RenewableTlsUtils; +import org.apache.pinot.common.utils.tls.TlsUtils; import org.apache.pinot.query.mailbox.MailboxService; import org.apache.pinot.spi.env.PinotConfiguration; import org.apache.pinot.spi.utils.CommonConstants; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.pinot.spi.utils.CommonConstants.Server.CONFIG_OF_GRPCTLS_SERVER_ENABLED; /** @@ -37,17 +54,33 @@ * send by the sender of the sender/receiver pair. */ public class GrpcMailboxServer extends PinotMailboxGrpc.PinotMailboxImplBase { + private static final Logger LOGGER = LoggerFactory.getLogger(GrpcMailboxServer.class); private static final long DEFAULT_SHUTDOWN_TIMEOUT_MS = 10_000L; private final MailboxService _mailboxService; private final Server _server; + // the key is the hashCode of the TlsConfig, the value is the SslContext + // We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the + // hashCode changes and the map is resized, the SslContext of the old hashCode will be lost. + private static final Map SERVER_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>(); public GrpcMailboxServer(MailboxService mailboxService, PinotConfiguration config) { + TlsConfig tlsConfig = null; _mailboxService = mailboxService; + if (config.getProperty(CONFIG_OF_GRPCTLS_SERVER_ENABLED, false)) { + tlsConfig = TlsUtils.extractTlsConfig(config, CommonConstants.Server.SERVER_GRPCTLS_PREFIX); + } int port = mailboxService.getPort(); - _server = ServerBuilder.forPort(port).addService(this).maxInboundMessageSize( - config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES, - CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)).build(); + if (tlsConfig != null) { + _server = NettyServerBuilder.forPort(port).addService(this).sslContext(buildGRpcSslContext(tlsConfig)) + .maxInboundMessageSize( + config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES, + CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)).build(); + } else { + _server = ServerBuilder.forPort(port).addService(this).maxInboundMessageSize( + config.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES, + CommonConstants.MultiStageQueryRunner.DEFAULT_MAX_INBOUND_QUERY_DATA_BLOCK_SIZE_BYTES)).build(); + } } public void start() { @@ -66,6 +99,31 @@ public void shutdown() { } } + private SslContext buildGRpcSslContext(TlsConfig tlsConfig) + throws IllegalArgumentException { + LOGGER.info("Building gRPC SSL context"); + if (tlsConfig.getKeyStorePath() == null) { + throw new IllegalArgumentException("Must provide key store path for secured gRpc server"); + } + SslContext sslContext = SERVER_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> { + try { + SSLFactory sslFactory = + RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores( + tlsConfig, PinotInsecureMode::isPinotInInsecureMode); + SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(sslFactory.getKeyManagerFactory().get()) + .sslProvider(SslProvider.valueOf(tlsConfig.getSslProvider())); + sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager); + if (tlsConfig.isClientAuthEnabled()) { + sslContextBuilder.clientAuth(ClientAuth.REQUIRE); + } + return GrpcSslContexts.configure(sslContextBuilder).build(); + } catch (Exception e) { + throw new RuntimeException("Failed to build gRPC SSL context", e); + } + }); + return sslContext; + } + @Override public StreamObserver open(StreamObserver responseObserver) { return new MailboxContentObserver(_mailboxService, responseObserver); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java index 5b036930ce09..8e834e4c7083 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java @@ -21,11 +21,27 @@ import io.grpc.Deadline; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts; +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContext; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslContextBuilder; +import io.grpc.netty.shaded.io.netty.handler.ssl.SslProvider; import io.grpc.stub.StreamObserver; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.function.Consumer; +import javax.net.ssl.SSLException; +import nl.altindag.ssl.SSLFactory; +import org.apache.pinot.common.config.GrpcConfig; +import org.apache.pinot.common.config.TlsConfig; import org.apache.pinot.common.proto.PinotQueryWorkerGrpc; import org.apache.pinot.common.proto.Worker; +import org.apache.pinot.common.utils.tls.PinotInsecureMode; +import org.apache.pinot.common.utils.tls.RenewableTlsUtils; import org.apache.pinot.query.routing.QueryServerInstance; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** @@ -36,12 +52,28 @@ */ class DispatchClient { private static final StreamObserver NO_OP_CANCEL_STREAM_OBSERVER = new CancelObserver(); + private static final Logger LOGGER = LoggerFactory.getLogger(DispatchClient.class); + // the key is the hashCode of the TlsConfig, the value is the SslContext + // We don't use TlsConfig as the map key because the TlsConfig is mutable, which means the hashCode can change. If the + // hashCode changes and the map is resized, the SslContext of the old hashCode will be lost. + private static final Map CLIENT_SSL_CONTEXTS_CACHE = new ConcurrentHashMap<>(); private final ManagedChannel _channel; private final PinotQueryWorkerGrpc.PinotQueryWorkerStub _dispatchStub; public DispatchClient(String host, int port) { - _channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); + this(host, port, new GrpcConfig(Collections.emptyMap())); + } + public DispatchClient(String host, int port, GrpcConfig grpcConfig) { + if (grpcConfig.isUsePlainText()) { + _channel = + ManagedChannelBuilder.forAddress(host, port).maxInboundMessageSize(grpcConfig.getMaxInboundMessageSizeBytes()) + .usePlaintext().build(); + } else { + _channel = + NettyChannelBuilder.forAddress(host, port).maxInboundMessageSize(grpcConfig.getMaxInboundMessageSizeBytes()) + .sslContext(buildSslContext(grpcConfig.getTlsConfig())).build(); + } _dispatchStub = PinotQueryWorkerGrpc.newStub(_channel); } @@ -58,4 +90,26 @@ public void cancel(long requestId) { Worker.CancelRequest cancelRequest = Worker.CancelRequest.newBuilder().setRequestId(requestId).build(); _dispatchStub.cancel(cancelRequest, NO_OP_CANCEL_STREAM_OBSERVER); } + private SslContext buildSslContext(TlsConfig tlsConfig) { + LOGGER.info("Building gRPC SSL context"); + SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> { + try { + SSLFactory sslFactory = RenewableTlsUtils.createSSLFactoryAndEnableAutoRenewalWhenUsingFileStores(tlsConfig, + PinotInsecureMode::isPinotInInsecureMode); + SslContextBuilder sslContextBuilder = SslContextBuilder.forClient(); + sslFactory.getKeyManagerFactory().ifPresent(sslContextBuilder::keyManager); + sslFactory.getTrustManagerFactory().ifPresent(sslContextBuilder::trustManager); + if (tlsConfig.getSslProvider() != null) { + sslContextBuilder = + GrpcSslContexts.configure(sslContextBuilder, SslProvider.valueOf(tlsConfig.getSslProvider())); + } else { + sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder); + } + return sslContextBuilder.build(); + } catch (SSLException e) { + throw new RuntimeException("Failed to build gRPC SSL context", e); + } + }); + return sslContext; + } }