diff --git a/turms-gateway/src/main/java/im/turms/gateway/access/client/tcp/TcpServerFactory.java b/turms-gateway/src/main/java/im/turms/gateway/access/client/tcp/TcpServerFactory.java index 0aa4cb852e..0fb9cd5076 100644 --- a/turms-gateway/src/main/java/im/turms/gateway/access/client/tcp/TcpServerFactory.java +++ b/turms-gateway/src/main/java/im/turms/gateway/access/client/tcp/TcpServerFactory.java @@ -18,12 +18,12 @@ package im.turms.gateway.access.client.tcp; import java.net.InetSocketAddress; -import jakarta.annotation.Nullable; -import io.netty.channel.Channel; +import io.netty.channel.ChannelPipeline; import reactor.core.publisher.Sinks; import reactor.netty.Connection; import reactor.netty.DisposableServer; +import reactor.netty.NettyPipeline; import reactor.netty.tcp.TcpServer; import im.turms.gateway.access.client.common.channel.ServiceAvailabilityHandler; @@ -37,6 +37,7 @@ import im.turms.server.common.infra.healthcheck.ServerStatusManager; import im.turms.server.common.infra.metrics.TurmsMicrometerChannelMetricsRecorder; import im.turms.server.common.infra.net.BindException; +import im.turms.server.common.infra.net.SslContextSpecType; import im.turms.server.common.infra.net.SslUtil; import im.turms.server.common.infra.property.constant.RemoteAddressSourceProxyProtocolMode; import im.turms.server.common.infra.property.env.common.SslProperties; @@ -56,7 +57,6 @@ public final class TcpServerFactory { private TcpServerFactory() { } - @Nullable public static DisposableServer create( TcpProperties tcpProperties, BlocklistService blocklistService, @@ -73,6 +73,8 @@ public static DisposableServer create( RemoteAddressSourceProxyProtocolMode proxyProtocolMode = tcpProperties.getRemoteAddressSource() .getProxyProtocolMode(); + + Sinks.One remoteAddressSink = Sinks.one(); TcpServer server = TcpServer.create() .host(host) .port(port) @@ -90,15 +92,14 @@ public static DisposableServer create( () -> new TurmsMicrometerChannelMetricsRecorder( MetricNameConst.CLIENT_NETWORK, "tcp")) - .doOnChannelInit((connectionObserver, channel, remoteAddress) -> channel.pipeline() - .addFirst("serviceAvailabilityHandler", serviceAvailabilityHandler)) - .handle((in, out) -> { - Connection connection = (Connection) in; - Sinks.One remoteAddressSink = Sinks.one(); + // Called for every new connection that is opened. + .doOnChannelInit((connectionObserver, channel, remoteAddress) -> { + ChannelPipeline pipeline = channel.pipeline(); + pipeline.addFirst("serviceAvailabilityHandler", serviceAvailabilityHandler); // Inbound - connection.addHandlerLast("varintLengthBasedFrameDecoder", + pipeline.addBefore(NettyPipeline.ReactiveBridge, + "varintLengthBasedFrameDecoder", CodecFactory.getExtendedVarintLengthBasedFrameDecoder(maxFrameLength)); - Channel channel = connection.channel(); if (RemoteAddressSourceProxyProtocolMode.REQUIRED == proxyProtocolMode) { HAProxyUtil.addProxyProtocolHandlers(channel.pipeline(), address -> { if (blocklistService.isIpBlocked(address.getAddress() @@ -108,8 +109,6 @@ public static DisposableServer create( remoteAddressSink.tryEmitValue(address); } }); - channel.config() - .setAutoRead(true); } else if (RemoteAddressSourceProxyProtocolMode.OPTIONAL == proxyProtocolMode) { HAProxyUtil.addProxyProtocolDetectorHandler(channel.pipeline(), address -> { if (blocklistService.isIpBlocked(address.getAddress() @@ -119,22 +118,35 @@ public static DisposableServer create( remoteAddressSink.tryEmitValue(address); } }); - channel.config() - .setAutoRead(true); } else { remoteAddressSink.tryEmitValue((InetSocketAddress) channel.remoteAddress()); } // Outbound - connection.addHandlerLast("varintLengthFieldPrepender", + pipeline.addLast("varintLengthFieldPrepender", CodecFactory.getVarintLengthFieldPrepender()); // For advanced operations, they encode objects to buffers themselves, // "protobufFrameEncoder" will just ignore them. But some simple // operations will pass TurmsNotification instances down, so we still need to // encode them. - connection.addHandlerLast("protobufFrameEncoder", + pipeline.addLast("protobufFrameEncoder", CodecFactory.getProtobufFrameEncoder()); - + }) + // Called when a connection is read (in/after channelActive(...)). + .handle((in, out) -> { + Connection connection = (Connection) in; + // Note: + // 1. We need to trigger the read event manually here. + // Otherwise, it will never read the inbound stream from the peer + // because we don't subscribe to the inbound stream until we get the peer + // address. + // 2. Although "setAutoRead" seems just setting the "autoRead" flag to true, + // it also triggers the read event under the hood. + // 3. Don't move "setAutoRead" to the callback of "doOnChannelInit" because the + // channel is not ready yet, and "setAutoRead" will not work. + connection.channel() + .config() + .setAutoRead(true); return remoteAddressSink.asMono() .flatMap(remoteAddress -> connectionListener.onAdded(connection, remoteAddress, @@ -144,7 +156,8 @@ public static DisposableServer create( }); SslProperties ssl = tcpProperties.getSsl(); if (ssl.isEnabled()) { - server.secure(spec -> SslUtil.configureSslContextSpec(spec, ssl, true)); + server = server.secure(spec -> SslUtil + .configureSslContextSpec(spec, SslContextSpecType.TCP, ssl, true)); } try { return server.bind() diff --git a/turms-gateway/src/main/java/im/turms/gateway/access/client/websocket/WebSocketServerFactory.java b/turms-gateway/src/main/java/im/turms/gateway/access/client/websocket/WebSocketServerFactory.java index 0795d1c958..6a896bfd70 100644 --- a/turms-gateway/src/main/java/im/turms/gateway/access/client/websocket/WebSocketServerFactory.java +++ b/turms-gateway/src/main/java/im/turms/gateway/access/client/websocket/WebSocketServerFactory.java @@ -47,6 +47,7 @@ import im.turms.server.common.infra.healthcheck.ServerStatusManager; import im.turms.server.common.infra.metrics.TurmsMicrometerChannelMetricsRecorder; import im.turms.server.common.infra.net.BindException; +import im.turms.server.common.infra.net.SslContextSpecType; import im.turms.server.common.infra.net.SslUtil; import im.turms.server.common.infra.property.constant.RemoteAddressSourceHttpHeaderMode; import im.turms.server.common.infra.property.constant.RemoteAddressSourceProxyProtocolMode; @@ -135,7 +136,10 @@ public static DisposableServer create( } SslProperties ssl = webSocketProperties.getSsl(); if (ssl.isEnabled()) { - server.secure(spec -> SslUtil.configureSslContextSpec(spec, ssl, true), true); + server = server.secure( + spec -> SslUtil + .configureSslContextSpec(spec, SslContextSpecType.HTTP11, ssl, true), + true); } try { return server.bind() @@ -197,7 +201,7 @@ private static Publisher handleHttpRequest( // Note that: // 1. PingWebSocketFrame will be handled by Netty itself // 2. The flatMap is called by FluxReceive, which will release buffer after - // "onNext" returns + // "onNext" returns. .flatMap(frame -> frame instanceof BinaryWebSocketFrame ? Mono.just(frame.content()) : Mono.empty()); diff --git a/turms-gateway/src/main/java/im/turms/gateway/infra/ldap/LdapClient.java b/turms-gateway/src/main/java/im/turms/gateway/infra/ldap/LdapClient.java index 5871066391..459e470c66 100644 --- a/turms-gateway/src/main/java/im/turms/gateway/infra/ldap/LdapClient.java +++ b/turms-gateway/src/main/java/im/turms/gateway/infra/ldap/LdapClient.java @@ -52,6 +52,7 @@ import im.turms.gateway.infra.ldap.handler.LdapMessageEncoder; import im.turms.server.common.infra.logging.core.logger.Logger; import im.turms.server.common.infra.logging.core.logger.LoggerFactory; +import im.turms.server.common.infra.net.SslContextSpecType; import im.turms.server.common.infra.net.SslUtil; import im.turms.server.common.infra.property.env.common.SslProperties; @@ -113,8 +114,11 @@ public boolean isConnected() { .port(port) .metrics(true, () -> new MicrometerChannelMetricsRecorder(LDAP_CLIENT, "ldap")); if (sslProperties != null && sslProperties.isEnabled()) { - client.secure(sslContextSpec -> SslUtil - .configureSslContextSpec(sslContextSpec, sslProperties, false)); + client = client + .secure(sslContextSpec -> SslUtil.configureSslContextSpec(sslContextSpec, + SslContextSpecType.TCP, + sslProperties, + false)); } return client.connect() .map(conn -> { diff --git a/turms-gateway/src/test/java/integration/access/client/tcp/TcpServerFactoryTests.java b/turms-gateway/src/test/java/integration/access/client/tcp/TcpServerFactoryTests.java index 01d5e83be9..8c9f77ed13 100644 --- a/turms-gateway/src/test/java/integration/access/client/tcp/TcpServerFactoryTests.java +++ b/turms-gateway/src/test/java/integration/access/client/tcp/TcpServerFactoryTests.java @@ -19,35 +19,52 @@ import java.io.IOException; import java.net.InetSocketAddress; +import java.nio.file.Path; import java.time.Duration; import io.netty.buffer.ByteBufInputStream; +import io.netty.channel.ChannelPipeline; import io.netty.handler.codec.haproxy.HAProxyCommand; import io.netty.handler.codec.haproxy.HAProxyMessage; import io.netty.handler.codec.haproxy.HAProxyMessageEncoder; import io.netty.handler.codec.haproxy.HAProxyProtocolVersion; import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.mockito.Mockito; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.core.publisher.Sinks; -import reactor.netty.Connection; +import reactor.core.scheduler.Schedulers; import reactor.netty.DisposableServer; +import reactor.netty.NettyPipeline; +import reactor.netty.channel.ChannelOperations; import reactor.netty.tcp.TcpClient; import reactor.test.StepVerifier; import im.turms.gateway.access.client.tcp.TcpServerFactory; +import im.turms.gateway.domain.session.service.SessionService; import im.turms.server.common.access.client.codec.CodecFactory; +import im.turms.server.common.access.client.dto.notification.TurmsNotification; import im.turms.server.common.access.client.dto.request.TurmsRequest; import im.turms.server.common.access.client.dto.request.message.CreateMessageRequest; import im.turms.server.common.domain.blocklist.service.BlocklistService; import im.turms.server.common.infra.healthcheck.ServerStatusManager; import im.turms.server.common.infra.healthcheck.ServiceAvailability; +import im.turms.server.common.infra.net.SslContextSpecType; +import im.turms.server.common.infra.net.SslUtil; import im.turms.server.common.infra.property.constant.RemoteAddressSourceProxyProtocolMode; +import im.turms.server.common.infra.property.env.common.SslProperties; import im.turms.server.common.infra.property.env.gateway.network.TcpProperties; import im.turms.server.common.infra.property.env.gateway.network.TcpRemoteAddressSourceProperties; import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import static im.turms.server.common.infra.unit.ByteSizeUnit.KB; /** * @author James Chen @@ -65,9 +82,160 @@ class TcpServerFactoryTests { private static final Duration TIMEOUT = Duration.ofSeconds(10); + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldSendAndReceiveStreamSuccessfully_forBothSslAndNonSslConnections(boolean enableSsl) { + Duration timeout = Duration.ofSeconds(10); + + // 1. Prepare properties + SslProperties serverSslProperties; + SslProperties clientSslProperties; + if (enableSsl) { + String alias = "turms"; + String storePassword = "im.turms"; + + Path keyStorePath = Path.of("server-keystore-test.p12") + .toAbsolutePath() + .normalize(); + keyStorePath.toFile() + .deleteOnExit(); + String keyStorePathStr = keyStorePath.toString(); + + Path certificatePath = Path.of("server-certificate-test.pem") + .toAbsolutePath() + .normalize(); + certificatePath.toFile() + .deleteOnExit(); + String certificatePathStr = certificatePath.toString(); + + Path trustStorePath = Path.of("client-truststore-test.jks") + .toAbsolutePath() + .normalize(); + trustStorePath.toFile() + .deleteOnExit(); + String trustStorePathStr = trustStorePath.toString(); + + SslUtil.generateKeyStore(keyStorePathStr, + "RSA", + alias, + storePassword, + 2048, + 3650, + "Turms", + "san=ip:127.0.0.1,dns:localhost") + .then(Mono.defer(() -> SslUtil.generateCertificate(certificatePathStr, + keyStorePathStr, + alias, + storePassword))) + .then(Mono.defer(() -> SslUtil.generateTrustStore(trustStorePathStr, + certificatePathStr, + storePassword, + storePassword))) + .block(timeout); + serverSslProperties = new SslProperties().toBuilder() + .enabled(true) + .keyStore(keyStorePathStr) + .keyStorePassword(storePassword) + .build(); + clientSslProperties = new SslProperties().toBuilder() + .enabled(true) + .trustStore(trustStorePathStr) + .trustStorePassword(storePassword) + .build(); + } else { + serverSslProperties = new SslProperties(); + clientSslProperties = new SslProperties(); + } + + TcpProperties tcpProperties = new TcpProperties().toBuilder() + .remoteAddressSource(TcpRemoteAddressSourceProperties.builder() + .proxyProtocolMode(RemoteAddressSourceProxyProtocolMode.DISABLED) + .build()) + .ssl(serverSslProperties) + .port(0) + .wiretap(true) + .build(); + + // 2. Set up server + BlocklistService blocklistService = mock(BlocklistService.class); + when(blocklistService.isIpBlocked(any(byte[].class))).thenReturn(false); + + ServerStatusManager serverStatusManager = mock(ServerStatusManager.class); + when(serverStatusManager.getServiceAvailability()) + .thenReturn(ServiceAvailability.AVAILABLE); + + SessionService sessionService = mock(SessionService.class); + + int requestId = 123456; + DisposableServer server = TcpServerFactory.create(tcpProperties, + blocklistService, + serverStatusManager, + sessionService, + (connection, remoteAddress, in, out, onClose) -> { + out.sendObject(TurmsNotification.newBuilder() + .setRequestId(requestId) + .build()) + .then() + .subscribe(); + return Mono.never(); + }, + 8 * KB); + + Sinks.One responseSink = Sinks.one(); + + // 3. Set up client + TcpClient tcpClient = TcpClient.create() + .host("127.0.0.1") + .port(server.port()); + if (enableSsl) { + tcpClient = tcpClient + .secure(sslContextSpec -> SslUtil.configureSslContextSpec(sslContextSpec, + SslContextSpecType.TCP, + clientSslProperties, + false)); + } + tcpClient.doOnChannelInit((connectionObserver, channel, remoteAddress) -> { + ChannelPipeline pipeline = channel.pipeline(); + // Inbound + pipeline + // Ensure our inbound handlers are before the handler for FluxReceive. + .addBefore(NettyPipeline.ReactiveBridge, + "varintLengthBasedFrameDecoder", + CodecFactory.getVarintLengthBasedFrameDecoder()) + .addBefore(NettyPipeline.ReactiveBridge, + "turmsNotificationDecoder", + CodecFactory.getTurmsNotificationDecoder()) + // Outbound + .addLast("varintLengthFieldPrepender", + CodecFactory.getVarintLengthFieldPrepender()) + .addLast("protobufFrameEncoder", CodecFactory.getProtobufFrameEncoder()); + }) + .connect() + .doOnSuccess(conn -> { + ChannelOperations connection = (ChannelOperations) conn; + connection.receiveObject() + .cast(TurmsNotification.class) + .doOnNext(responseSink::tryEmitValue) + .subscribe(null, responseSink::tryEmitError); + connection.sendObject(TurmsRequest.newBuilder() + .setRequestId(requestId) + .build()) + .then() + .subscribe(null, responseSink::tryEmitError); + }) + .subscribeOn(Schedulers.newSingle("test-tcp-client")) + .block(Duration.ofSeconds(30)); + + // 4. Verify + StepVerifier.create(responseSink.asMono() + .timeout(timeout)) + .expectNextCount(1) + .verifyComplete(); + } + @Test void haproxy_v1_tcp4() { - verify(new HAProxyMessage( + shouldSendAndReceiveStreamSuccessfully_whenHaproxyEnabled(new HAProxyMessage( HAProxyProtocolVersion.V1, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, @@ -79,7 +247,7 @@ void haproxy_v1_tcp4() { @Test void haproxy_v2_tcp4() { - verify(new HAProxyMessage( + shouldSendAndReceiveStreamSuccessfully_whenHaproxyEnabled(new HAProxyMessage( HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, HAProxyProxiedProtocol.TCP4, @@ -89,7 +257,7 @@ void haproxy_v2_tcp4() { SERVER_PORT)); } - void verify(HAProxyMessage message) { + void shouldSendAndReceiveStreamSuccessfully_whenHaproxyEnabled(HAProxyMessage message) { // 1. set up server Sinks.One remoteClientAddressSink = Sinks.one(); Sinks.One turmsRequestSink = Sinks.one(); @@ -130,7 +298,7 @@ void verify(HAProxyMessage message) { try { // 2. set up client - Connection fakeProxy = TcpClient.create() + TcpClient.create() .port(SERVER_PORT) .host(SERVER_HOST) .doOnChannelInit( @@ -138,11 +306,18 @@ void verify(HAProxyMessage message) { .addFirst(HAProxyMessageEncoder.INSTANCE, CodecFactory.getVarintLengthFieldPrepender(), CodecFactory.getProtobufFrameEncoder())) - .connectNow(TIMEOUT); - fakeProxy.outbound() - .sendObject(Flux.just(message, TURMS_REQUEST)) - .then() - .block(TIMEOUT); + .connect() + .doOnSuccess(connection -> connection.outbound() + .sendObject(Flux.just(message, TURMS_REQUEST)) + .then() + .subscribe(null, t -> { + remoteClientAddressSink.tryEmitError(t); + turmsRequestSink.tryEmitError(t); + })) + .subscribe(null, t -> { + remoteClientAddressSink.tryEmitError(t); + turmsRequestSink.tryEmitError(t); + }); // 3. test StepVerifier.create(remoteClientAddressSink.asMono() diff --git a/turms-server-common/src/main/java/im/turms/server/common/access/admin/web/HttpServerFactory.java b/turms-server-common/src/main/java/im/turms/server/common/access/admin/web/HttpServerFactory.java index 04fd9d280e..5906aeb122 100644 --- a/turms-server-common/src/main/java/im/turms/server/common/access/admin/web/HttpServerFactory.java +++ b/turms-server-common/src/main/java/im/turms/server/common/access/admin/web/HttpServerFactory.java @@ -53,7 +53,7 @@ public static HttpServer createHttpServer(AdminHttpProperties httpProperties) { .childOption(TCP_NODELAY, true) .runOn(LoopResourcesFactory.createForServer(ThreadNameConst.ADMIN_HTTP_PREFIX)) .metrics(true, () -> new TurmsMicrometerChannelMetricsRecorder(ADMIN_API, "http")); - return SslUtil.apply(http, ssl, null); + return SslUtil.apply(http, ssl, false); } } diff --git a/turms-server-common/src/main/java/im/turms/server/common/infra/client/TurmsTcpClient.java b/turms-server-common/src/main/java/im/turms/server/common/infra/client/TurmsTcpClient.java index f4af4c9c45..8b0d551c5a 100644 --- a/turms-server-common/src/main/java/im/turms/server/common/infra/client/TurmsTcpClient.java +++ b/turms-server-common/src/main/java/im/turms/server/common/infra/client/TurmsTcpClient.java @@ -19,7 +19,9 @@ import jakarta.annotation.Nullable; +import io.netty.channel.ChannelPipeline; import reactor.core.publisher.Mono; +import reactor.netty.NettyPipeline; import reactor.netty.channel.ChannelOperations; import reactor.netty.resources.LoopResources; import reactor.netty.tcp.TcpClient; @@ -68,22 +70,25 @@ public Mono connect(String host, int port, LoopResources loopResources) { .host(host) .port(port) .runOn(loopResources) - .connect() - .doOnNext(conn -> { - connection = (ChannelOperations) conn; - - connection + .doOnChannelInit((connectionObserver, channel, remoteAddress) -> { + ChannelPipeline pipeline = channel.pipeline(); + pipeline // Inbound - .addHandlerLast("varintLengthBasedFrameDecoder", + .addBefore(NettyPipeline.ReactiveBridge, + "varintLengthBasedFrameDecoder", CodecFactory.getVarintLengthBasedFrameDecoder()) - .addHandlerLast("turmsNotificationDecoder", + .addBefore(NettyPipeline.ReactiveBridge, + "turmsNotificationDecoder", CodecFactory.getTurmsNotificationDecoder()) // Outbound - .addHandlerFirst("protobufFrameEncoder", - CodecFactory.getProtobufFrameEncoder()) - .addHandlerFirst("varintLengthFieldPrepender", - CodecFactory.getVarintLengthFieldPrepender()); - + .addLast("varintLengthFieldPrepender", + CodecFactory.getVarintLengthFieldPrepender()) + .addLast("protobufFrameEncoder", + CodecFactory.getProtobufFrameEncoder()); + }) + .connect() + .doOnNext(conn -> { + connection = (ChannelOperations) conn; connection.receiveObject() .cast(TurmsNotification.class) .subscribe(this::handleResponse, diff --git a/turms-server-common/src/main/java/im/turms/server/common/infra/cluster/service/connection/ConnectionServer.java b/turms-server-common/src/main/java/im/turms/server/common/infra/cluster/service/connection/ConnectionServer.java index 22bc4bdb33..f5535544a5 100644 --- a/turms-server-common/src/main/java/im/turms/server/common/infra/cluster/service/connection/ConnectionServer.java +++ b/turms-server-common/src/main/java/im/turms/server/common/infra/cluster/service/connection/ConnectionServer.java @@ -32,6 +32,7 @@ import im.turms.server.common.infra.logging.core.logger.Logger; import im.turms.server.common.infra.logging.core.logger.LoggerFactory; import im.turms.server.common.infra.net.BindException; +import im.turms.server.common.infra.net.SslContextSpecType; import im.turms.server.common.infra.net.SslUtil; import im.turms.server.common.infra.property.env.common.SslProperties; import im.turms.server.common.infra.thread.ThreadNameConst; @@ -92,7 +93,8 @@ public synchronized void blockUntilConnect() { .doOnConnection(connection -> connectionConsumer .accept((ChannelOperations) connection)); if (ssl.isEnabled()) { - tcpServer.secure(spec -> SslUtil.configureSslContextSpec(spec, ssl, true)); + tcpServer = tcpServer.secure(spec -> SslUtil + .configureSslContextSpec(spec, SslContextSpecType.TCP, ssl, true)); } server = tcpServer.bindNow(DurationConst.ONE_MINUTE); LOGGER.info("The local node server started on: {}:{}", host, currentPort); diff --git a/turms-server-common/src/main/java/im/turms/server/common/infra/cluster/service/connection/ConnectionService.java b/turms-server-common/src/main/java/im/turms/server/common/infra/cluster/service/connection/ConnectionService.java index d236fa0583..57e200ad9f 100644 --- a/turms-server-common/src/main/java/im/turms/server/common/infra/cluster/service/connection/ConnectionService.java +++ b/turms-server-common/src/main/java/im/turms/server/common/infra/cluster/service/connection/ConnectionService.java @@ -55,6 +55,7 @@ import im.turms.server.common.infra.logging.core.logger.Logger; import im.turms.server.common.infra.logging.core.logger.LoggerFactory; import im.turms.server.common.infra.logging.core.model.LogLevel; +import im.turms.server.common.infra.net.SslContextSpecType; import im.turms.server.common.infra.net.SslUtil; import im.turms.server.common.infra.property.env.common.SslProperties; import im.turms.server.common.infra.property.env.common.cluster.connection.ConnectionClientProperties; @@ -253,8 +254,10 @@ private Mono initTcpConnection(String host, int port) { .metrics(true, () -> new MicrometerChannelMetricsRecorder(NODE_TCP_CLIENT, "tcp")) .runOn(eventLoopGroupForClients); if (clientSsl.isEnabled()) { - client.secure(sslContextSpec -> SslUtil - .configureSslContextSpec(sslContextSpec, clientSsl, false)); + client = client.secure(sslContextSpec -> SslUtil.configureSslContextSpec(sslContextSpec, + SslContextSpecType.TCP, + clientSsl, + false)); } return client.connect(); } diff --git a/turms-server-common/src/main/java/im/turms/server/common/infra/io/ProcessUtil.java b/turms-server-common/src/main/java/im/turms/server/common/infra/io/ProcessUtil.java new file mode 100644 index 0000000000..a7f02cc7ee --- /dev/null +++ b/turms-server-common/src/main/java/im/turms/server/common/infra/io/ProcessUtil.java @@ -0,0 +1,92 @@ +/* + * Copyright (C) 2019 The Turms Project + * https://github.com/turms-im/turms + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package im.turms.server.common.infra.io; + +import java.io.IOException; +import java.io.InputStream; + +import reactor.core.publisher.Mono; + +import im.turms.server.common.infra.lang.StringUtil; +import im.turms.server.common.infra.reactor.PublisherUtil; + +/** + * @author James Chen + */ +public final class ProcessUtil { + + private ProcessUtil() { + } + + public static Mono run(String... command) { + ProcessBuilder builder = new ProcessBuilder(command); + return PublisherUtil.fromFuture(() -> builder.start() + .onExit()) + .flatMap(process -> { + int exitValue = process.exitValue(); + if (exitValue == 0) { + return Mono.empty(); + } + InputStream inputStream = process.getInputStream(); + InputStream errorStream = process.getErrorStream(); + String info = null; + String error = null; + Exception suppressedException = null; + try { + info = inputStream.available() == 0 + ? null + : new String(inputStream.readAllBytes()); + error = errorStream.available() == 0 + ? null + : new String(errorStream.readAllBytes()); + } catch (IOException e) { + suppressedException = e; + } + String message; + if (StringUtil.isNotBlank(info)) { + if (StringUtil.isNotBlank(error)) { + message = "Exit value: " + + exitValue + + ". Info: " + + info + + ". Error: " + + error; + } else { + message = "Exit value: " + + exitValue + + ". Info: " + + info; + } + } else if (StringUtil.isNotBlank(error)) { + message = "Exit value: " + + exitValue + + ". Error: " + + error; + } else { + message = "Exit value: " + + exitValue; + } + RuntimeException exception = new RuntimeException(message); + if (suppressedException != null) { + exception.addSuppressed(suppressedException); + } + return Mono.error(exception); + }); + } + +} \ No newline at end of file diff --git a/turms-server-common/src/main/java/im/turms/server/common/infra/net/SslContextSpecType.java b/turms-server-common/src/main/java/im/turms/server/common/infra/net/SslContextSpecType.java new file mode 100644 index 0000000000..862c00e6e7 --- /dev/null +++ b/turms-server-common/src/main/java/im/turms/server/common/infra/net/SslContextSpecType.java @@ -0,0 +1,27 @@ +/* + * Copyright (C) 2019 The Turms Project + * https://github.com/turms-im/turms + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package im.turms.server.common.infra.net; + +/** + * @author James Chen + */ +public enum SslContextSpecType { + TCP, + HTTP11, + HTTP2 +} \ No newline at end of file diff --git a/turms-server-common/src/main/java/im/turms/server/common/infra/net/SslUtil.java b/turms-server-common/src/main/java/im/turms/server/common/infra/net/SslUtil.java index 9d893c6f35..8d21d81825 100644 --- a/turms-server-common/src/main/java/im/turms/server/common/infra/net/SslUtil.java +++ b/turms-server-common/src/main/java/im/turms/server/common/infra/net/SslUtil.java @@ -17,24 +17,27 @@ package im.turms.server.common.infra.net; +import java.io.File; import java.io.InputStream; import java.net.URL; +import java.nio.file.Path; import java.security.KeyStore; import java.util.Arrays; import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.TrustManagerFactory; import jakarta.annotation.Nullable; -import lombok.SneakyThrows; -import org.springframework.boot.web.server.SslConfigurationValidator; -import org.springframework.boot.web.server.SslStoreProvider; import org.springframework.util.ResourceUtils; +import org.springframework.util.StringUtils; +import reactor.core.publisher.Mono; import reactor.netty.http.Http11SslContextSpec; +import reactor.netty.http.Http2SslContextSpec; import reactor.netty.http.server.HttpServer; import reactor.netty.tcp.AbstractProtocolSslContextSpec; -import reactor.netty.tcp.DefaultSslContextSpec; import reactor.netty.tcp.SslProvider; +import reactor.netty.tcp.TcpSslContextSpec; +import im.turms.server.common.infra.io.ProcessUtil; import im.turms.server.common.infra.property.env.common.SslProperties; /** @@ -42,82 +45,186 @@ */ public final class SslUtil { + private static final String KEYTOOL = Path.of(System.getProperty("java.home") + + File.separator + + "bin" + + File.separator + + "keytool") + .toString(); + private SslUtil() { } + public static Mono generateKeyStore( + String keystorePath, + String keyAlgorithm, + String alias, + String storePassword, + int keySize, + int validity, + String organizationName, + String extension) { + return ProcessUtil.run(KEYTOOL, + "-genkeypair", + "-dname", + "o=" + + organizationName, + "-keyalg", + keyAlgorithm, + "-keysize", + String.valueOf(keySize), + "-alias", + alias, + "-validity", + String.valueOf(validity), + "-keystore", + keystorePath, + "-storepass", + storePassword, + "-ext", + extension); + } + + public static Mono generateTrustStore( + String truststorePath, + String certificatePath, + String keyPassword, + String storePassword) { + return ProcessUtil.run(KEYTOOL, + "-import", + "-trustcacerts", + "-file", + certificatePath, + "-keypass", + keyPassword, + "-storepass", + storePassword, + "-keystore", + truststorePath, + "-noprompt"); + } + + public static Mono generateCertificate( + String certificatePath, + String keystorePath, + String alias, + String storePassword) { + return ProcessUtil.run(KEYTOOL, + "-exportcert", + "-keystore", + keystorePath, + "-alias", + alias, + "-storepass", + storePassword, + "-rfc", + "-file", + certificatePath); + } + public static void configureSslContextSpec( SslProvider.SslContextSpec sslContextSpec, + SslContextSpecType sslContextSpecType, SslProperties ssl, boolean forServer) { - if (ssl.isEnabled()) { - SslProvider.ProtocolSslContextSpec spec = forServer - ? createSslContextSpec(ssl, null) - : getClientContextBuilder(ssl); - sslContextSpec.sslContext(spec); + if (!ssl.isEnabled()) { + return; } + SslProvider.ProtocolSslContextSpec spec = forServer + ? createSslContextSpec(ssl, sslContextSpecType) + : getClientContextBuilder(ssl, sslContextSpecType); + sslContextSpec.sslContext(spec); } - public static HttpServer apply( - HttpServer server, - SslProperties ssl, - @Nullable SslStoreProvider sslStoreProvider) { + public static HttpServer apply(HttpServer server, SslProperties ssl, boolean isHttp2) { if (ssl.isEnabled()) { - return server - .secure(spec -> spec.sslContext(createSslContextSpec(ssl, sslStoreProvider))); + return server.secure(spec -> spec.sslContext(createSslContextSpec(ssl, + isHttp2 + ? SslContextSpecType.HTTP2 + : SslContextSpecType.HTTP11))); } return server; } - private static SslProvider.ProtocolSslContextSpec getClientContextBuilder(SslProperties ssl) { - return DefaultSslContextSpec.forClient() - .configure(builder -> { - builder.keyManager(getKeyManagerFactory(ssl, null)) - .trustManager(getTrustManagerFactory(ssl, null)); - if (ssl.getEnabledProtocols() != null) { - builder.protocols(ssl.getEnabledProtocols()); - } - if (ssl.getCiphers() != null) { - builder.ciphers(Arrays.asList(ssl.getCiphers())); - } - }); + private static SslProvider.ProtocolSslContextSpec getClientContextBuilder( + SslProperties ssl, + SslContextSpecType sslContextSpecType) { + AbstractProtocolSslContextSpec sslContextSpec = switch (sslContextSpecType) { + case TCP -> TcpSslContextSpec.forClient(); + case HTTP11 -> Http11SslContextSpec.forClient(); + case HTTP2 -> Http2SslContextSpec.forClient(); + }; + return sslContextSpec.configure(builder -> { + String keyStore = ssl.getKeyStore(); + if (keyStore != null) { + builder.keyManager(getKeyManagerFactory(ssl)); + } + String trustStore = ssl.getTrustStore(); + if (trustStore != null) { + builder.trustManager(getTrustManagerFactory(ssl)); + } + String[] enabledProtocols = ssl.getEnabledProtocols(); + if (enabledProtocols != null) { + builder.protocols(enabledProtocols); + } + String[] ciphers = ssl.getCiphers(); + if (ciphers != null) { + builder.ciphers(Arrays.asList(ciphers)); + } + }); } private static AbstractProtocolSslContextSpec createSslContextSpec( SslProperties ssl, - SslStoreProvider sslStoreProvider) { - return Http11SslContextSpec.forServer(getKeyManagerFactory(ssl, sslStoreProvider)) - .configure(builder -> { - builder.trustManager(getTrustManagerFactory(ssl, sslStoreProvider)); - if (ssl.getEnabledProtocols() != null) { - builder.protocols(ssl.getEnabledProtocols()); - } - if (ssl.getCiphers() != null) { - builder.ciphers(Arrays.asList(ssl.getCiphers())); - } - if (ssl.getClientAuth() != null) { - builder.clientAuth(ssl.getClientAuth()); - } - }); + SslContextSpecType sslContextSpecType) { + AbstractProtocolSslContextSpec sslContextSpec = switch (sslContextSpecType) { + case TCP -> TcpSslContextSpec.forServer(getKeyManagerFactory(ssl)); + case HTTP11 -> Http11SslContextSpec.forServer(getKeyManagerFactory(ssl)); + case HTTP2 -> Http2SslContextSpec.forServer(getKeyManagerFactory(ssl)); + }; + return sslContextSpec.configure(builder -> { + builder.trustManager(getTrustManagerFactory(ssl)); + if (ssl.getEnabledProtocols() != null) { + builder.protocols(ssl.getEnabledProtocols()); + } + if (ssl.getCiphers() != null) { + builder.ciphers(Arrays.asList(ssl.getCiphers())); + } + if (ssl.getClientAuth() != null) { + builder.clientAuth(ssl.getClientAuth()); + } + }); } - private static KeyManagerFactory getKeyManagerFactory( - SslProperties ssl, - @Nullable SslStoreProvider sslStoreProvider) { + private static KeyManagerFactory getKeyManagerFactory(SslProperties ssl) { + String password = ssl.getKeyPassword(); + String keyStorePassword = ssl.getKeyStorePassword(); + char[] keyPassword; + if (password == null) { + if (keyStorePassword == null) { + throw new IllegalArgumentException("The key password is required"); + } + keyPassword = keyStorePassword.toCharArray(); + } else { + keyPassword = password.toCharArray(); + } try { - KeyStore keyStore = getKeyStore(ssl, sslStoreProvider); - SslConfigurationValidator.validateKeyAlias(keyStore, ssl.getKeyAlias()); - KeyManagerFactory keyManagerFactory = ssl.getKeyAlias() == null - ? KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()) - : new ConfigurableAliasKeyManagerFactory( - ssl.getKeyAlias(), - KeyManagerFactory.getDefaultAlgorithm()); - char[] keyPassword = ssl.getKeyPassword() == null - ? null - : ssl.getKeyPassword() - .toCharArray(); - if (keyPassword == null && ssl.getKeyStorePassword() != null) { - keyPassword = ssl.getKeyStorePassword() - .toCharArray(); + KeyStore keyStore = getKeyStore(ssl); + String alias = ssl.getKeyAlias(); + KeyManagerFactory keyManagerFactory; + if (StringUtils.hasLength(alias)) { + if (!keyStore.containsAlias(alias)) { + throw new RuntimeException( + "The keystore does not contain the alias: \"" + + alias + + "\""); + } + keyManagerFactory = new ConfigurableAliasKeyManagerFactory( + alias, + KeyManagerFactory.getDefaultAlgorithm()); + } else { + keyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); } keyManagerFactory.init(keyStore, keyPassword); return keyManagerFactory; @@ -126,42 +233,27 @@ private static KeyManagerFactory getKeyManagerFactory( } } - private static KeyStore getKeyStore( - SslProperties ssl, - @Nullable SslStoreProvider sslStoreProvider) throws Exception { - if (sslStoreProvider != null) { - return sslStoreProvider.getKeyStore(); - } + private static KeyStore getKeyStore(SslProperties ssl) { return loadStore(ssl.getKeyStoreType(), ssl.getKeyStoreProvider(), ssl.getKeyStore(), ssl.getKeyStorePassword()); } - private static TrustManagerFactory getTrustManagerFactory( - SslProperties ssl, - SslStoreProvider sslStoreProvider) { + private static TrustManagerFactory getTrustManagerFactory(SslProperties ssl) { try { - KeyStore store = getTrustStore(ssl, sslStoreProvider); + KeyStore store = getTrustStore(ssl); TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); trustManagerFactory.init(store); return trustManagerFactory; - } catch (Exception ex) { - throw new RuntimeException(ex); + } catch (Exception e) { + throw new RuntimeException("Failed to get the trust manager factory", e); } } - private static KeyStore getTrustStore( - SslProperties ssl, - @Nullable SslStoreProvider sslStoreProvider) { - if (sslStoreProvider != null) { - try { - return sslStoreProvider.getTrustStore(); - } catch (Exception e) { - throw new RuntimeException("Failed to get the instance of the trust store", e); - } - } + @Nullable + private static KeyStore getTrustStore(SslProperties ssl) { String trustStore = ssl.getTrustStore(); if (trustStore == null) { return null; @@ -172,15 +264,24 @@ private static KeyStore getTrustStore( ssl.getTrustStorePassword()); } - @SneakyThrows private static KeyStore loadStore( - String type, + @Nullable String type, @Nullable String provider, String resource, @Nullable String password) { - KeyStore store = provider == null - ? KeyStore.getInstance(type) - : KeyStore.getInstance(type, provider); + KeyStore store; + try { + if (type == null) { + type = resource.endsWith(".jks") + ? "JKS" + : "PKCS12"; + } + store = provider == null + ? KeyStore.getInstance(type) + : KeyStore.getInstance(type, provider); + } catch (Exception e) { + throw new RuntimeException("Failed to load key store", e); + } try { URL url = ResourceUtils.getURL(resource); try (InputStream stream = url.openStream()) { @@ -198,4 +299,4 @@ private static KeyStore loadStore( } } -} +} \ No newline at end of file diff --git a/turms-server-common/src/main/java/im/turms/server/common/infra/property/env/common/SslProperties.java b/turms-server-common/src/main/java/im/turms/server/common/infra/property/env/common/SslProperties.java index bf544e2fe6..ab80649ca9 100644 --- a/turms-server-common/src/main/java/im/turms/server/common/infra/property/env/common/SslProperties.java +++ b/turms-server-common/src/main/java/im/turms/server/common/infra/property/env/common/SslProperties.java @@ -18,12 +18,18 @@ package im.turms.server.common.infra.property.env.common; import io.netty.handler.ssl.ClientAuth; +import lombok.AllArgsConstructor; import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.experimental.SuperBuilder; /** * @author James Chen */ +@AllArgsConstructor +@SuperBuilder(toBuilder = true) @Data +@NoArgsConstructor public class SslProperties { private boolean enabled; @@ -44,4 +50,4 @@ public class SslProperties { private String trustStoreProvider; private String protocol = "TLS"; -} +} \ No newline at end of file diff --git a/turms-service/src/main/resources/keystore-test.p12 b/turms-service/src/main/resources/keystore-test.p12 deleted file mode 100644 index 24437ecb6e..0000000000 Binary files a/turms-service/src/main/resources/keystore-test.p12 and /dev/null differ