Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TLS Configuration Support for QueryServer and Dispatch Client #13645

Merged
merged 15 commits into from
Sep 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public GrpcQueryClient(String host, int port, GrpcConfig config) {
_channelShutdownTimeoutSeconds = config.getChannelShutdownTimeoutSecond();
}

private SslContext buildSslContext(TlsConfig tlsConfig) {
public static SslContext buildSslContext(TlsConfig tlsConfig) {
LOGGER.info("Building gRPC SSL context");
SslContext sslContext = CLIENT_SSL_CONTEXTS_CACHE.computeIfAbsent(tlsConfig.hashCode(), tlsConfigHashCode -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public GrpcQueryServer(int port, GrpcConfig config, TlsConfig tlsConfig, QueryEx
ResourceManager.DEFAULT_QUERY_WORKER_THREADS);
}

private SslContext buildGRpcSslContext(TlsConfig tlsConfig)
public static SslContext buildGRpcSslContext(TlsConfig tlsConfig)
throws IllegalArgumentException {
LOGGER.info("Building gRPC SSL context");
if (tlsConfig.getKeyStorePath() == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@
import io.grpc.Deadline;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.StreamObserver;
import java.util.Collections;
import java.util.function.Consumer;
import org.apache.pinot.common.config.GrpcConfig;
import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.common.utils.grpc.GrpcQueryClient;
import org.apache.pinot.query.routing.QueryServerInstance;


Expand All @@ -41,7 +45,17 @@ class DispatchClient {
private final PinotQueryWorkerGrpc.PinotQueryWorkerStub _dispatchStub;

public DispatchClient(String host, int port) {
_channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
this(host, port, new GrpcConfig(Collections.emptyMap()));
}

public DispatchClient(String host, int port, GrpcConfig grpcConfig) {
Copy link
Contributor

@soumitra-st soumitra-st Oct 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This constructor is only used in the other constructor to create non-tls DispatchClient. Does it mean that if gRPC TLS is enabled on the server side, then DispatchClient will always fail? I don't see a place where TLS enabled DispatchClient is created. cc: @gortiz @Jackie-Jiang

if (grpcConfig.isUsePlainText()) {
_channel = ManagedChannelBuilder.forAddress(host, port).usePlaintext().build();
} else {
_channel =
NettyChannelBuilder.forAddress(host, port)
.sslContext(GrpcQueryClient.buildSslContext(grpcConfig.getTlsConfig())).build();
}
_dispatchStub = PinotQueryWorkerGrpc.newStub(_channel);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@

import io.grpc.Server;
import io.grpc.ServerBuilder;
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.exception.QueryException;
import org.apache.pinot.common.proto.PinotQueryWorkerGrpc;
import org.apache.pinot.common.proto.Worker;
import org.apache.pinot.common.utils.NamedThreadFactory;
import org.apache.pinot.core.transport.grpc.GrpcQueryServer;
import org.apache.pinot.query.routing.QueryPlanSerDeUtils;
import org.apache.pinot.query.routing.StageMetadata;
import org.apache.pinot.query.routing.StagePlan;
Expand All @@ -55,16 +58,18 @@ public class QueryServer extends PinotQueryWorkerGrpc.PinotQueryWorkerImplBase {

private final int _port;
private final QueryRunner _queryRunner;
private final TlsConfig _tlsConfig;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nullable annotation is needed here

// query submission service is only used for plan submission for now.
// TODO: with complex query submission logic we should allow asynchronous query submission return instead of
// directly return from submission response observer.
private final ExecutorService _querySubmissionExecutorService;

private Server _server = null;

public QueryServer(int port, QueryRunner queryRunner) {
public QueryServer(int port, QueryRunner queryRunner, TlsConfig tlsConfig) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nullabe annotation is needed here

_port = port;
_queryRunner = queryRunner;
_tlsConfig = tlsConfig;
_querySubmissionExecutorService =
Executors.newCachedThreadPool(new NamedThreadFactory("query_submission_executor_on_" + _port + "_port"));
}
Expand All @@ -73,7 +78,14 @@ public void start() {
LOGGER.info("Starting QueryServer");
try {
if (_server == null) {
_server = ServerBuilder.forPort(_port).addService(this).maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build();
if (_tlsConfig == null) {
_server = ServerBuilder.forPort(_port).addService(this)
.maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build();
} else {
_server = NettyServerBuilder.forPort(_port).addService(this)
.sslContext(GrpcQueryServer.buildGRpcSslContext(_tlsConfig))
.maxInboundMessageSize(MAX_INBOUND_MESSAGE_SIZE).build();
}
LOGGER.info("Initialized QueryServer on port: {}", _port);
}
_queryRunner.start();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public void setUp()
for (int i = 0; i < QUERY_SERVER_COUNT; i++) {
int availablePort = QueryTestUtils.getAvailablePort();
QueryRunner queryRunner = Mockito.mock(QueryRunner.class);
QueryServer queryServer = Mockito.spy(new QueryServer(availablePort, queryRunner));
QueryServer queryServer = Mockito.spy(new QueryServer(availablePort, queryRunner, null));
queryServer.start();
_queryServerMap.put(availablePort, queryServer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public void setUp()
for (int i = 0; i < QUERY_SERVER_COUNT; i++) {
int availablePort = QueryTestUtils.getAvailablePort();
QueryRunner queryRunner = mock(QueryRunner.class);
QueryServer queryServer = new QueryServer(availablePort, queryRunner);
QueryServer queryServer = new QueryServer(availablePort, queryRunner, null);
queryServer.start();
_queryServerMap.put(availablePort, queryServer);
_queryRunnerMap.put(availablePort, queryRunner);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,25 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo
TlsUtils.extractTlsConfig(serverConf.getPinotConfig(), CommonConstants.Server.SERVER_TLS_PREFIX);
NettyConfig nettyConfig =
NettyConfig.extractNettyConfig(serverConf.getPinotConfig(), CommonConstants.Server.SERVER_NETTY_PREFIX);
accessControlFactory
.init(serverConf.getPinotConfig().subset(CommonConstants.Server.PREFIX_OF_CONFIG_OF_ACCESS_CONTROL),
helixManager);
accessControlFactory.init(
serverConf.getPinotConfig().subset(CommonConstants.Server.PREFIX_OF_CONFIG_OF_ACCESS_CONTROL), helixManager);
_accessControl = accessControlFactory.create();

if (serverConf.isMultiStageServerEnabled()) {
LOGGER.info("Initializing Multi-stage query engine");
_workerQueryServer = new WorkerQueryServer(serverConf.getPinotConfig(), _instanceDataManager, helixManager,
_serverMetrics);
_workerQueryServer =
new WorkerQueryServer(serverConf.getPinotConfig(), _instanceDataManager, helixManager, _serverMetrics,
serverConf.isNettyTlsServerEnabled() ? tlsConfig : null);
} else {
_workerQueryServer = null;
}

if (serverConf.isNettyServerEnabled()) {
int nettyPort = serverConf.getNettyPort();
LOGGER.info("Initializing Netty query server on port: {}", nettyPort);
_instanceRequestHandler = ChannelHandlerFactory
.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(), _queryScheduler,
_serverMetrics, new AllowAllAccessFactory().create());
_instanceRequestHandler =
ChannelHandlerFactory.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(),
_queryScheduler, _serverMetrics, new AllowAllAccessFactory().create());
_nettyQueryServer = new QueryServer(nettyPort, nettyConfig, _instanceRequestHandler);
} else {
_nettyQueryServer = null;
Expand All @@ -146,9 +146,9 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo
if (serverConf.isNettyTlsServerEnabled()) {
int nettySecPort = serverConf.getNettyTlsPort();
LOGGER.info("Initializing TLS-secured Netty query server on port: {}", nettySecPort);
_instanceRequestHandler = ChannelHandlerFactory
.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(), _queryScheduler,
_serverMetrics, _accessControl);
_instanceRequestHandler =
ChannelHandlerFactory.getInstanceRequestHandler(helixManager.getInstanceName(), serverConf.getPinotConfig(),
_queryScheduler, _serverMetrics, _accessControl);
_nettyTlsQueryServer = new QueryServer(nettySecPort, nettyConfig, tlsConfig, _instanceRequestHandler);
} else {
_nettyTlsQueryServer = null;
Expand All @@ -157,9 +157,8 @@ public ServerInstance(ServerConf serverConf, HelixManager helixManager, AccessCo
int grpcPort = serverConf.getGrpcPort();
LOGGER.info("Initializing gRPC query server on port: {}", grpcPort);
_grpcQueryServer = new GrpcQueryServer(grpcPort, GrpcConfig.buildGrpcQueryConfig(serverConf.getPinotConfig()),
serverConf.isGrpcTlsServerEnabled() ? TlsUtils
.extractTlsConfig(serverConf.getPinotConfig(), CommonConstants.Server.SERVER_GRPCTLS_PREFIX) : null,
_queryExecutor, _serverMetrics, _accessControl);
serverConf.isGrpcTlsServerEnabled() ? TlsUtils.extractTlsConfig(serverConf.getPinotConfig(),
CommonConstants.Server.SERVER_GRPCTLS_PREFIX) : null, _queryExecutor, _serverMetrics, _accessControl);
} else {
_grpcQueryServer = null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.pinot.server.worker;

import org.apache.helix.HelixManager;
import org.apache.pinot.common.config.TlsConfig;
import org.apache.pinot.common.metrics.ServerMetrics;
import org.apache.pinot.core.data.manager.InstanceDataManager;
import org.apache.pinot.query.runtime.QueryRunner;
Expand All @@ -37,19 +38,20 @@ public class WorkerQueryServer {
private QueryRunner _queryRunner;
private InstanceDataManager _instanceDataManager;
private ServerMetrics _serverMetrics;
private TlsConfig _tlsConfig;

public WorkerQueryServer(PinotConfiguration configuration, InstanceDataManager instanceDataManager,
HelixManager helixManager, ServerMetrics serverMetrics) {
HelixManager helixManager, ServerMetrics serverMetrics, TlsConfig tlsConfig) {
_configuration = toWorkerQueryConfig(configuration);
_helixManager = helixManager;
_instanceDataManager = instanceDataManager;
_tlsConfig = tlsConfig;
_serverMetrics = serverMetrics;
_queryServicePort =
_configuration.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT);
_queryServicePort = _configuration.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT);
_queryRunner = new QueryRunner();
_queryRunner.init(_configuration, _instanceDataManager, _helixManager, _serverMetrics);
_queryWorkerService = new QueryServer(_queryServicePort, _queryRunner);
_queryWorkerService = new QueryServer(_queryServicePort, _queryRunner, _tlsConfig);
}

private static PinotConfiguration toWorkerQueryConfig(PinotConfiguration configuration) {
Expand All @@ -62,17 +64,15 @@ private static PinotConfiguration toWorkerQueryConfig(PinotConfiguration configu
CommonConstants.Helix.SERVER_INSTANCE_PREFIX_LENGTH) : instanceId;
newConfig.addProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_HOSTNAME, hostname);
}
int runnerPort = newConfig.getProperty(
CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT,
int runnerPort = newConfig.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_RUNNER_PORT);
if (runnerPort == -1) {
runnerPort =
newConfig.getProperty(CommonConstants.Server.CONFIG_OF_GRPC_PORT, CommonConstants.Server.DEFAULT_GRPC_PORT);
newConfig.addProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_RUNNER_PORT, runnerPort);
}
int servicePort =
newConfig.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT);
int servicePort = newConfig.getProperty(CommonConstants.MultiStageQueryRunner.KEY_OF_QUERY_SERVER_PORT,
CommonConstants.MultiStageQueryRunner.DEFAULT_QUERY_SERVER_PORT);
if (servicePort == -1) {
servicePort = newConfig.getProperty(CommonConstants.Helix.KEY_OF_SERVER_NETTY_PORT,
CommonConstants.Helix.DEFAULT_SERVER_NETTY_PORT);
Expand Down
Loading