Skip to content

Commit

Permalink
TLS Configuration Support for QueryServer and Dispatch Client (#13645)
Browse files Browse the repository at this point in the history
  • Loading branch information
anandheritage authored Sep 18, 2024
1 parent c565a83 commit 21ff6bf
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public GrpcQueryClient(String host, int port, GrpcConfig config) {
_channelShutdownTimeoutSeconds = config.getChannelShutdownTimeoutSecond();
}

private SslContext buildSslContext(TlsConfig tlsConfig) {
public static SslContext buildSslContext(TlsConfig tlsConfig) {
LOGGER.info("Building gRPC SSL context");
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public GrpcQueryServer(int port, GrpcConfig config, TlsConfig tlsConfig, QueryEx
ResourceManager.DEFAULT_QUERY_WORKER_THREADS);
}

private SslContext buildGRpcSslContext(TlsConfig tlsConfig)
public static SslContext buildGRpcSslContext(TlsConfig tlsConfig)
throws IllegalArgumentException {
LOGGER.info("Building gRPC SSL context");
if (tlsConfig.getKeyStorePath() == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import java.util.Collections;
import java.util.function.Consumer;
import org.apache.pinot.common.config.GrpcConfig;
import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.common.utils.grpc.GrpcQueryClient;
import org.apache.pinot.query.routing.QueryServerInstance;


Expand All @@ -41,7 +45,17 @@ class DispatchClient {
private final PinotQueryWorkerGrpc.PinotQueryWorkerStub _dispatchStub;

public DispatchClient(String host, int port) {
_channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
this(host, port, new GrpcConfig(Collections.emptyMap()));
}

public DispatchClient(String host, int port, GrpcConfig grpcConfig) {
if (grpcConfig.isUsePlainText()) {
_channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
} else {
_channel =
NettyChannelBuilder.forAddress(host, port)
.sslContext(GrpcQueryClient.buildSslContext(grpcConfig.getTlsConfig())).build();
}
_dispatchStub = PinotQueryWorkerGrpc.newStub(_channel);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@

import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.exception.QueryException;
import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.common.utils.NamedThreadFactory;
import org.apache.pinot.core.transport.grpc.GrpcQueryServer;
import org.apache.pinot.query.routing.QueryPlanSerDeUtils;
import org.apache.pinot.query.routing.StageMetadata;
import org.apache.pinot.query.routing.StagePlan;
Expand All @@ -55,16 +58,18 @@ public class QueryServer extends PinotQueryWorkerGrpc.PinotQueryWorkerImplBase {

private final int _port;
private final QueryRunner _queryRunner;
private final TlsConfig _tlsConfig;
// query submission service is only used for plan submission for now.
// TODO: with complex query submission logic we should allow asynchronous query submission return instead of
// directly return from submission response observer.
private final ExecutorService _querySubmissionExecutorService;

private Server _server = null;

public QueryServer(int port, QueryRunner queryRunner) {
public QueryServer(int port, QueryRunner queryRunner, TlsConfig tlsConfig) {
_port = port;
_queryRunner = queryRunner;
_tlsConfig = tlsConfig;
_querySubmissionExecutorService =
Executors.newCachedThreadPool(new NamedThreadFactory("query_submission_executor_on_" + _port + "_port"));
}
Expand All @@ -73,7 +78,14 @@ public void start() {
LOGGER.info("Starting QueryServer");
try {
if (_server == null) {
_server = ServerBuilder.forPort(_port).addService(this).maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build();
if (_tlsConfig == null) {
_server = ServerBuilder.forPort(_port).addService(this)
.maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build();
} else {
_server = NettyServerBuilder.forPort(_port).addService(this)
.sslContext(GrpcQueryServer.buildGRpcSslContext(_tlsConfig))
.maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build();
}
LOGGER.info("Initialized QueryServer on port: {}", _port);
}
_queryRunner.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public void setUp()
for (int i = 0; i < QUERY_SERVER_COUNT; i++) {
int availablePort = QueryTestUtils.getAvailablePort();
QueryRunner queryRunner = Mockito.mock(QueryRunner.class);
QueryServer queryServer = Mockito.spy(new QueryServer(availablePort, queryRunner));
QueryServer queryServer = Mockito.spy(new QueryServer(availablePort, queryRunner, null));
queryServer.start();
_queryServerMap.put(availablePort, queryServer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void setUp()
for (int i = 0; i < QUERY_SERVER_COUNT; i++) {
int availablePort = QueryTestUtils.getAvailablePort();
QueryRunner queryRunner = mock(QueryRunner.class);
QueryServer queryServer = new QueryServer(availablePort, queryRunner);
QueryServer queryServer = new QueryServer(availablePort, queryRunner, null);
queryServer.start();
_queryServerMap.put(availablePort, queryServer);
_queryRunnerMap.put(availablePort, queryRunner);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,25 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo
TlsUtils.extractTlsConfig(serverConf.getPinotConfig(), CommonConstants.Server.SERVER_TLS_PREFIX);
NettyConfig nettyConfig =
NettyConfig.extractNettyConfig(serverConf.getPinotConfig(), CommonConstants.Server.SERVER_NETTY_PREFIX);
accessControlFactory
.init(serverConf.getPinotConfig().subset(CommonConstants.Server.PREFIX_OF_CONFIG_OF_ACCESS_CONTROL),
helixManager);
accessControlFactory.init(
serverConf.getPinotConfig().subset(CommonConstants.Server.PREFIX_OF_CONFIG_OF_ACCESS_CONTROL), helixManager);
_accessControl = accessControlFactory.create();

if (serverConf.isMultiStageServerEnabled()) {
LOGGER.info("Initializing Multi-stage query engine");
_workerQueryServer = new WorkerQueryServer(serverConf.getPinotConfig(), _instanceDataManager, helixManager,
_serverMetrics);
_workerQueryServer =
new WorkerQueryServer(serverConf.getPinotConfig(), _instanceDataManager, helixManager, _serverMetrics,
serverConf.isNettyTlsServerEnabled() ? tlsConfig : null);
} else {
_workerQueryServer = null;
}

if (serverConf.isNettyServerEnabled()) {
int nettyPort = serverConf.getNettyPort();
LOGGER.info("Initializing Netty query server on port: {}", nettyPort);
_instanceRequestHandler = ChannelHandlerFactory
.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(), _queryScheduler,
_serverMetrics, new AllowAllAccessFactory().create());
_instanceRequestHandler =
ChannelHandlerFactory.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(),
_queryScheduler, _serverMetrics, new AllowAllAccessFactory().create());
_nettyQueryServer = new QueryServer(nettyPort, nettyConfig, _instanceRequestHandler);
} else {
_nettyQueryServer = null;
Expand All @@ -146,9 +146,9 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo
if (serverConf.isNettyTlsServerEnabled()) {
int nettySecPort = serverConf.getNettyTlsPort();
LOGGER.info("Initializing TLS-secured Netty query server on port: {}", nettySecPort);
_instanceRequestHandler = ChannelHandlerFactory
.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(), _queryScheduler,
_serverMetrics, _accessControl);
_instanceRequestHandler =
ChannelHandlerFactory.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(),
_queryScheduler, _serverMetrics, _accessControl);
_nettyTlsQueryServer = new QueryServer(nettySecPort, nettyConfig, tlsConfig, _instanceRequestHandler);
} else {
_nettyTlsQueryServer = null;
Expand All @@ -157,9 +157,8 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo
int grpcPort = serverConf.getGrpcPort();
LOGGER.info("Initializing gRPC query server on port: {}", grpcPort);
_grpcQueryServer = new GrpcQueryServer(grpcPort, GrpcConfig.buildGrpcQueryConfig(serverConf.getPinotConfig()),
serverConf.isGrpcTlsServerEnabled() ? TlsUtils
.extractTlsConfig(serverConf.getPinotConfig(), CommonConstants.Server.SERVER_GRPCTLS_PREFIX) : null,
_queryExecutor, _serverMetrics, _accessControl);
serverConf.isGrpcTlsServerEnabled() ? TlsUtils.extractTlsConfig(serverConf.getPinotConfig(),
CommonConstants.Server.SERVER_GRPCTLS_PREFIX) : null, _queryExecutor, _serverMetrics, _accessControl);
} else {
_grpcQueryServer = null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.pinot.server.worker;

import org.apache.helix.HelixManager;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.metrics.ServerMetrics;
import org.apache.pinot.core.data.manager.InstanceDataManager;
import org.apache.pinot.query.runtime.QueryRunner;
Expand All @@ -37,19 +38,20 @@ public class WorkerQueryServer {
private QueryRunner _queryRunner;
private InstanceDataManager _instanceDataManager;
private ServerMetrics _serverMetrics;
private TlsConfig _tlsConfig;

public WorkerQueryServer(PinotConfiguration configuration, InstanceDataManager instanceDataManager,
HelixManager helixManager, ServerMetrics serverMetrics) {
HelixManager helixManager, ServerMetrics serverMetrics, TlsConfig tlsConfig) {
_configuration = toWorkerQueryConfig(configuration);
_helixManager = helixManager;
_instanceDataManager = instanceDataManager;
_tlsConfig = tlsConfig;
_serverMetrics = serverMetrics;
_queryServicePort =
_configuration.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT);
_queryServicePort = _configuration.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT);
_queryRunner = new QueryRunner();
_queryRunner.init(_configuration, _instanceDataManager, _helixManager, _serverMetrics);
_queryWorkerService = new QueryServer(_queryServicePort, _queryRunner);
_queryWorkerService = new QueryServer(_queryServicePort, _queryRunner, _tlsConfig);
}

private static PinotConfiguration toWorkerQueryConfig(PinotConfiguration configuration) {
Expand All @@ -62,17 +64,15 @@ private static PinotConfiguration toWorkerQueryConfig(PinotConfiguration configu
CommonConstants.Helix.SERVER_INSTANCE_PREFIX_LENGTH) : instanceId;
newConfig.addProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_HOSTNAME, hostname);
}
int runnerPort = newConfig.getProperty(
CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT,
int runnerPort = newConfig.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_RUNNER_PORT);
if (runnerPort == -1) {
runnerPort =
newConfig.getProperty(CommonConstants.Server.CONFIG_OF_GRPC_PORT, CommonConstants.Server.DEFAULT_GRPC_PORT);
newConfig.addProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT, runnerPort);
}
int servicePort =
newConfig.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT);
int servicePort = newConfig.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT);
if (servicePort == -1) {
servicePort = newConfig.getProperty(CommonConstants.Helix.KEY_OF_SERVER_NETTY_PORT,
CommonConstants.Helix.DEFAULT_SERVER_NETTY_PORT);
Expand Down

0 comments on commit 21ff6bf

Please sign in to comment.