diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java index 9987876b9580..808e8d0e024f 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/grpc/GrpcQueryClient.java @@ -84,7 +84,7 @@ public GrpcQueryClient(String host, int port, GrpcConfig config) { _channelShutdownTimeoutSeconds = config.getChannelShutdownTimeoutSecond(); } - private SslContext buildSslContext(TlsConfig tlsConfig) { + public static SslContext buildSslContext(TlsConfig tlsConfig) { LOGGER.info("Building gRPC SSL context"); SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> { try { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java b/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java index daae6d74cca0..4d8608c5eaa1 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/transport/grpc/GrpcQueryServer.java @@ -119,7 +119,7 @@ public GrpcQueryServer(int port, GrpcConfig config, TlsConfig tlsConfig, QueryEx ResourceManager.DEFAULT_QUERY_WORKER_THREADS); } - private SslContext buildGRpcSslContext(TlsConfig tlsConfig) + public static SslContext buildGRpcSslContext(TlsConfig tlsConfig) throws IllegalArgumentException { LOGGER.info("Building gRPC SSL context"); if (tlsConfig.getKeyStorePath() == null) { diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java index 5b036930ce09..cea23218cd44 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/DispatchClient.java @@ -21,10 +21,14 @@ import io.grpc.Deadline; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder; import io.grpc.stub.StreamObserver; +import java.util.Collections; import java.util.function.Consumer; +import org.apache.pinot.common.config.GrpcConfig; import org.apache.pinot.common.proto.PinotQueryWorkerGrpc; import org.apache.pinot.common.proto.Worker; +import org.apache.pinot.common.utils.grpc.GrpcQueryClient; import org.apache.pinot.query.routing.QueryServerInstance; @@ -41,7 +45,17 @@ class DispatchClient { private final PinotQueryWorkerGrpc.PinotQueryWorkerStub _dispatchStub; public DispatchClient(String host, int port) { - _channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); + this(host, port, new GrpcConfig(Collections.emptyMap())); + } + + public DispatchClient(String host, int port, GrpcConfig grpcConfig) { + if (grpcConfig.isUsePlainText()) { + _channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build(); + } else { + _channel = + NettyChannelBuilder.forAddress(host, port) + .sslContext(GrpcQueryClient.buildSslContext(grpcConfig.getTlsConfig())).build(); + } _dispatchStub = PinotQueryWorkerGrpc.newStub(_channel); } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java index 763192e16ed8..65e4ca7df077 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/service/server/QueryServer.java @@ -20,6 +20,7 @@ import io.grpc.Server; import io.grpc.ServerBuilder; +import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder; import io.grpc.stub.StreamObserver; import java.util.List; import java.util.Map; @@ -27,10 +28,12 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import org.apache.pinot.common.config.TlsConfig; import org.apache.pinot.common.exception.QueryException; import org.apache.pinot.common.proto.PinotQueryWorkerGrpc; import org.apache.pinot.common.proto.Worker; import org.apache.pinot.common.utils.NamedThreadFactory; +import org.apache.pinot.core.transport.grpc.GrpcQueryServer; import org.apache.pinot.query.routing.QueryPlanSerDeUtils; import org.apache.pinot.query.routing.StageMetadata; import org.apache.pinot.query.routing.StagePlan; @@ -55,6 +58,7 @@ public class QueryServer extends PinotQueryWorkerGrpc.PinotQueryWorkerImplBase { private final int _port; private final QueryRunner _queryRunner; + private final TlsConfig _tlsConfig; // query submission service is only used for plan submission for now. // TODO: with complex query submission logic we should allow asynchronous query submission return instead of // directly return from submission response observer. @@ -62,9 +66,10 @@ public class QueryServer extends PinotQueryWorkerGrpc.PinotQueryWorkerImplBase { private Server _server = null; - public QueryServer(int port, QueryRunner queryRunner) { + public QueryServer(int port, QueryRunner queryRunner, TlsConfig tlsConfig) { _port = port; _queryRunner = queryRunner; + _tlsConfig = tlsConfig; _querySubmissionExecutorService = Executors.newCachedThreadPool(new NamedThreadFactory("query_submission_executor_on_" + _port + "_port")); } @@ -73,7 +78,14 @@ public void start() { LOGGER.info("Starting QueryServer"); try { if (_server == null) { - _server = ServerBuilder.forPort(_port).addService(this).maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build(); + if (_tlsConfig == null) { + _server = ServerBuilder.forPort(_port).addService(this) + .maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build(); + } else { + _server = NettyServerBuilder.forPort(_port).addService(this) + .sslContext(GrpcQueryServer.buildGRpcSslContext(_tlsConfig)) + .maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build(); + } LOGGER.info("Initialized QueryServer on port: {}", _port); } _queryRunner.start(); diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java index 694fb3c08779..c21c40b2d9db 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/dispatch/QueryDispatcherTest.java @@ -60,7 +60,7 @@ public void setUp() for (int i = 0; i < QUERY_SERVER_COUNT; i++) { int availablePort = QueryTestUtils.getAvailablePort(); QueryRunner queryRunner = Mockito.mock(QueryRunner.class); - QueryServer queryServer = Mockito.spy(new QueryServer(availablePort, queryRunner)); + QueryServer queryServer = Mockito.spy(new QueryServer(availablePort, queryRunner, null)); queryServer.start(); _queryServerMap.put(availablePort, queryServer); } diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java index 3a0b23408eab..7a14a2a4c64d 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/service/server/QueryServerTest.java @@ -76,7 +76,7 @@ public void setUp() for (int i = 0; i < QUERY_SERVER_COUNT; i++) { int availablePort = QueryTestUtils.getAvailablePort(); QueryRunner queryRunner = mock(QueryRunner.class); - QueryServer queryServer = new QueryServer(availablePort, queryRunner); + QueryServer queryServer = new QueryServer(availablePort, queryRunner, null); queryServer.start(); _queryServerMap.put(availablePort, queryServer); _queryRunnerMap.put(availablePort, queryRunner); diff --git a/pinot-server/src/main/java/org/apache/pinot/server/starter/ServerInstance.java b/pinot-server/src/main/java/org/apache/pinot/server/starter/ServerInstance.java index 2a75ca7f5aba..01f4402710dd 100644 --- a/pinot-server/src/main/java/org/apache/pinot/server/starter/ServerInstance.java +++ b/pinot-server/src/main/java/org/apache/pinot/server/starter/ServerInstance.java @@ -119,15 +119,15 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo TlsUtils.extractTlsConfig(serverConf.getPinotConfig(), CommonConstants.Server.SERVER_TLS_PREFIX); NettyConfig nettyConfig = NettyConfig.extractNettyConfig(serverConf.getPinotConfig(), CommonConstants.Server.SERVER_NETTY_PREFIX); - accessControlFactory - .init(serverConf.getPinotConfig().subset(CommonConstants.Server.PREFIX_OF_CONFIG_OF_ACCESS_CONTROL), - helixManager); + accessControlFactory.init( + serverConf.getPinotConfig().subset(CommonConstants.Server.PREFIX_OF_CONFIG_OF_ACCESS_CONTROL), helixManager); _accessControl = accessControlFactory.create(); if (serverConf.isMultiStageServerEnabled()) { LOGGER.info("Initializing Multi-stage query engine"); - _workerQueryServer = new WorkerQueryServer(serverConf.getPinotConfig(), _instanceDataManager, helixManager, - _serverMetrics); + _workerQueryServer = + new WorkerQueryServer(serverConf.getPinotConfig(), _instanceDataManager, helixManager, _serverMetrics, + serverConf.isNettyTlsServerEnabled() ? tlsConfig : null); } else { _workerQueryServer = null; } @@ -135,9 +135,9 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo if (serverConf.isNettyServerEnabled()) { int nettyPort = serverConf.getNettyPort(); LOGGER.info("Initializing Netty query server on port: {}", nettyPort); - _instanceRequestHandler = ChannelHandlerFactory - .getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(), _queryScheduler, - _serverMetrics, new AllowAllAccessFactory().create()); + _instanceRequestHandler = + ChannelHandlerFactory.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(), + _queryScheduler, _serverMetrics, new AllowAllAccessFactory().create()); _nettyQueryServer = new QueryServer(nettyPort, nettyConfig, _instanceRequestHandler); } else { _nettyQueryServer = null; @@ -146,9 +146,9 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo if (serverConf.isNettyTlsServerEnabled()) { int nettySecPort = serverConf.getNettyTlsPort(); LOGGER.info("Initializing TLS-secured Netty query server on port: {}", nettySecPort); - _instanceRequestHandler = ChannelHandlerFactory - .getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(), _queryScheduler, - _serverMetrics, _accessControl); + _instanceRequestHandler = + ChannelHandlerFactory.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(), + _queryScheduler, _serverMetrics, _accessControl); _nettyTlsQueryServer = new QueryServer(nettySecPort, nettyConfig, tlsConfig, _instanceRequestHandler); } else { _nettyTlsQueryServer = null; @@ -157,9 +157,8 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo int grpcPort = serverConf.getGrpcPort(); LOGGER.info("Initializing gRPC query server on port: {}", grpcPort); _grpcQueryServer = new GrpcQueryServer(grpcPort, GrpcConfig.buildGrpcQueryConfig(serverConf.getPinotConfig()), - serverConf.isGrpcTlsServerEnabled() ? TlsUtils - .extractTlsConfig(serverConf.getPinotConfig(), CommonConstants.Server.SERVER_GRPCTLS_PREFIX) : null, - _queryExecutor, _serverMetrics, _accessControl); + serverConf.isGrpcTlsServerEnabled() ? TlsUtils.extractTlsConfig(serverConf.getPinotConfig(), + CommonConstants.Server.SERVER_GRPCTLS_PREFIX) : null, _queryExecutor, _serverMetrics, _accessControl); } else { _grpcQueryServer = null; } diff --git a/pinot-server/src/main/java/org/apache/pinot/server/worker/WorkerQueryServer.java b/pinot-server/src/main/java/org/apache/pinot/server/worker/WorkerQueryServer.java index 45db3208ec3b..542cc9bd90eb 100644 --- a/pinot-server/src/main/java/org/apache/pinot/server/worker/WorkerQueryServer.java +++ b/pinot-server/src/main/java/org/apache/pinot/server/worker/WorkerQueryServer.java @@ -19,6 +19,7 @@ package org.apache.pinot.server.worker; import org.apache.helix.HelixManager; +import org.apache.pinot.common.config.TlsConfig; import org.apache.pinot.common.metrics.ServerMetrics; import org.apache.pinot.core.data.manager.InstanceDataManager; import org.apache.pinot.query.runtime.QueryRunner; @@ -37,19 +38,20 @@ public class WorkerQueryServer { private QueryRunner _queryRunner; private InstanceDataManager _instanceDataManager; private ServerMetrics _serverMetrics; + private TlsConfig _tlsConfig; public WorkerQueryServer(PinotConfiguration configuration, InstanceDataManager instanceDataManager, - HelixManager helixManager, ServerMetrics serverMetrics) { + HelixManager helixManager, ServerMetrics serverMetrics, TlsConfig tlsConfig) { _configuration = toWorkerQueryConfig(configuration); _helixManager = helixManager; _instanceDataManager = instanceDataManager; + _tlsConfig = tlsConfig; _serverMetrics = serverMetrics; - _queryServicePort = - _configuration.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT, - CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT); + _queryServicePort = _configuration.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT, + CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT); _queryRunner = new QueryRunner(); _queryRunner.init(_configuration, _instanceDataManager, _helixManager, _serverMetrics); - _queryWorkerService = new QueryServer(_queryServicePort, _queryRunner); + _queryWorkerService = new QueryServer(_queryServicePort, _queryRunner, _tlsConfig); } private static PinotConfiguration toWorkerQueryConfig(PinotConfiguration configuration) { @@ -62,17 +64,15 @@ private static PinotConfiguration toWorkerQueryConfig(PinotConfiguration configu CommonConstants.Helix.SERVER_INSTANCE_PREFIX_LENGTH) : instanceId; newConfig.addProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_HOSTNAME, hostname); } - int runnerPort = newConfig.getProperty( - CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT, + int runnerPort = newConfig.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT, CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_RUNNER_PORT); if (runnerPort == -1) { runnerPort = newConfig.getProperty(CommonConstants.Server.CONFIG_OF_GRPC_PORT, CommonConstants.Server.DEFAULT_GRPC_PORT); newConfig.addProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT, runnerPort); } - int servicePort = - newConfig.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT, - CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT); + int servicePort = newConfig.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT, + CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT); if (servicePort == -1) { servicePort = newConfig.getProperty(CommonConstants.Helix.KEY_OF_SERVER_NETTY_PORT, CommonConstants.Helix.DEFAULT_SERVER_NETTY_PORT);