diff --git a/zuul-core/build.gradle b/zuul-core/build.gradle index ada379f44a..c462a05194 100644 --- a/zuul-core/build.gradle +++ b/zuul-core/build.gradle @@ -6,8 +6,9 @@ dependencies { implementation libraries.guava // TODO(carl-mastrangelo): this can be implementation; remove Logger from public api points. api libraries.slf4j - implementation 'org.bouncycastle:bcprov-jdk18on:1.76' - implementation 'org.bouncycastle:bcpkix-jdk18on:1.76' + implementation 'org.bouncycastle:bcprov-jdk18on:1.78' + implementation 'org.bouncycastle:bcpkix-jdk18on:1.78' + implementation 'org.bouncycastle:bctls-jdk18on:1.78' implementation 'com.fasterxml.jackson.core:jackson-core:2.16.1' api 'com.fasterxml.jackson.core:jackson-databind:2.16.1' diff --git a/zuul-core/src/main/java/com/netflix/zuul/netty/server/http2/Http2OrHttpHandler.java b/zuul-core/src/main/java/com/netflix/zuul/netty/server/http2/Http2OrHttpHandler.java index 5d8c906669..1125c7c0ea 100644 --- a/zuul-core/src/main/java/com/netflix/zuul/netty/server/http2/Http2OrHttpHandler.java +++ b/zuul-core/src/main/java/com/netflix/zuul/netty/server/http2/Http2OrHttpHandler.java @@ -20,6 +20,7 @@ import com.netflix.netty.common.channel.config.CommonChannelConfigKeys; import com.netflix.netty.common.http2.DynamicHttp2FrameLogger; import com.netflix.zuul.netty.server.BaseZuulChannelInitializer; +import com.netflix.zuul.netty.server.psk.TlsPskHandler; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; @@ -33,7 +34,9 @@ import io.netty.handler.logging.LogLevel; import io.netty.handler.ssl.ApplicationProtocolNames; import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.util.AttributeKey; + import java.util.function.Consumer; /** @@ -45,6 +48,8 @@ public class Http2OrHttpHandler extends ApplicationProtocolNegotiationHandler { public static final AttributeKey PROTOCOL_NAME = AttributeKey.valueOf("protocol_name"); + private static final String FALLBACK_APPLICATION_PROTOCOL = ApplicationProtocolNames.HTTP_1_1; + private static final DynamicHttp2FrameLogger FRAME_LOGGER = new DynamicHttp2FrameLogger(LogLevel.DEBUG, Http2FrameCodec.class); @@ -60,7 +65,7 @@ public Http2OrHttpHandler( ChannelHandler http2StreamHandler, ChannelConfig channelConfig, Consumer addHttpHandlerFn) { - super(ApplicationProtocolNames.HTTP_1_1); + super(FALLBACK_APPLICATION_PROTOCOL); this.http2StreamHandler = http2StreamHandler; this.maxConcurrentStreams = channelConfig.get(CommonChannelConfigKeys.maxConcurrentStreams); this.initialWindowSize = channelConfig.get(CommonChannelConfigKeys.initialWindowSize); @@ -70,6 +75,42 @@ public Http2OrHttpHandler( this.addHttpHandlerFn = addHttpHandlerFn; } + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent handshakeEvent) { + if (handshakeEvent.isSuccess()) { + TlsPskHandler tlsPskHandler = ctx.channel().pipeline().get(TlsPskHandler.class); + if (tlsPskHandler != null) { + // PSK mode + try { + String tlsPskApplicationProtocol = tlsPskHandler.getApplicationProtocol(); + configurePipeline( + ctx, + tlsPskApplicationProtocol != null + ? tlsPskApplicationProtocol + : FALLBACK_APPLICATION_PROTOCOL); + } catch (Throwable cause) { + exceptionCaught(ctx, cause); + } finally { + // Handshake failures are handled in exceptionCaught(...). + if (handshakeEvent.isSuccess()) { + removeSelfIfPresent(ctx); + } + } + } else { + // non PSK mode + super.userEventTriggered(ctx, evt); + } + } else { + // handshake failures + // TODO sunnys - handle PSK handshake failures + super.userEventTriggered(ctx, evt); + } + } else { + super.userEventTriggered(ctx, evt); + } + } + @Override protected void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception { if (ApplicationProtocolNames.HTTP_2.equals(protocol)) { @@ -118,4 +159,11 @@ private void configureHttp2(ChannelPipeline pipeline) { private void configureHttp1(ChannelPipeline pipeline) { addHttpHandlerFn.accept(pipeline); } + + private void removeSelfIfPresent(ChannelHandlerContext ctx) { + ChannelPipeline pipeline = ctx.pipeline(); + if (!ctx.isRemoved()) { + pipeline.remove(this); + } + } } diff --git a/zuul-core/src/main/java/com/netflix/zuul/netty/server/psk/ExternalTlsPskProvider.java b/zuul-core/src/main/java/com/netflix/zuul/netty/server/psk/ExternalTlsPskProvider.java new file mode 100644 index 0000000000..e3507b7f90 --- /dev/null +++ b/zuul-core/src/main/java/com/netflix/zuul/netty/server/psk/ExternalTlsPskProvider.java @@ -0,0 +1,25 @@ +/* + * Copyright 2024 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.zuul.netty.server.psk; + +import org.bouncycastle.tls.TlsPSKExternal; + +import java.util.Vector; + +public interface ExternalTlsPskProvider { + TlsPSKExternal provide(Vector clientPskIdentities); +} diff --git a/zuul-core/src/main/java/com/netflix/zuul/netty/server/psk/TlsPskHandler.java b/zuul-core/src/main/java/com/netflix/zuul/netty/server/psk/TlsPskHandler.java new file mode 100644 index 0000000000..100ccfc68f --- /dev/null +++ b/zuul-core/src/main/java/com/netflix/zuul/netty/server/psk/TlsPskHandler.java @@ -0,0 +1,459 @@ +/* + * Copyright 2024 Netflix, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.netflix.zuul.netty.server.psk; + +import com.netflix.spectator.api.Registry; +import com.netflix.spectator.api.Timer; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; +import io.netty.util.ReferenceCountUtil; +import org.bouncycastle.tls.AbstractTlsServer; +import org.bouncycastle.tls.AlertDescription; +import org.bouncycastle.tls.AlertLevel; +import org.bouncycastle.tls.CipherSuite; +import org.bouncycastle.tls.ProtocolName; +import org.bouncycastle.tls.ProtocolVersion; +import org.bouncycastle.tls.TlsCredentials; +import org.bouncycastle.tls.TlsFatalAlert; +import org.bouncycastle.tls.TlsPSKExternal; +import org.bouncycastle.tls.TlsServerProtocol; +import org.bouncycastle.tls.TlsUtils; +import org.bouncycastle.tls.crypto.TlsCrypto; +import org.bouncycastle.tls.crypto.impl.jcajce.JcaTlsCryptoProvider; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionContext; +import javax.security.cert.X509Certificate; +import java.io.IOException; +import java.net.SocketAddress; +import java.security.Principal; +import java.security.SecureRandom; +import java.security.cert.Certificate; +import java.util.Hashtable; +import java.util.List; +import java.util.Map; +import java.util.Vector; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +public class TlsPskHandler extends ByteToMessageDecoder + implements ChannelOutboundHandler, ChannelInboundHandler { + + public static final Map SUPPORTED_TLS_PSK_CIPHER_SUITE_MAP = Map.of( + CipherSuite.TLS_AES_128_GCM_SHA256, + "TLS_AES_128_GCM_SHA256", + CipherSuite.TLS_AES_256_GCM_SHA384, + "TLS_AES_256_GCM_SHA384"); + + private final Registry registry; + private final ExternalTlsPskProvider externalTlsPskProvider; + + private ZuulPskServer tlsPskServer; + + private TlsPskServerProtocol tlsPskServerProtocol; + + public TlsPskHandler(Registry registry, ExternalTlsPskProvider externalTlsPskProvider) { + super(); + this.registry = registry; + this.externalTlsPskProvider = externalTlsPskProvider; + } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { + ctx.bind(localAddress, promise); + } + + @Override + public void connect( + ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) + throws Exception { + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.close(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.deregister(promise); + } + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + ctx.read(); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (!(msg instanceof ByteBuf byteBufMsg)) { + ctx.write(msg, promise); + return; + } + byte[] appDataBytes = byteBufMsg.hasArray() ? byteBufMsg.array() : readDirect(byteBufMsg); + ReferenceCountUtil.safeRelease(byteBufMsg); + tlsPskServerProtocol.writeApplicationData(appDataBytes, 0, appDataBytes.length); + int availableOutputBytes = tlsPskServerProtocol.getAvailableOutputBytes(); + if (availableOutputBytes != 0) { + byte[] outputBytes = new byte[availableOutputBytes]; + tlsPskServerProtocol.readOutput(outputBytes, 0, availableOutputBytes); + ctx.write(Unpooled.wrappedBuffer(outputBytes), promise) + .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + } + } + + @Override + public void flush(ChannelHandlerContext ctx) { + ctx.flush(); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + final byte[] bytesRead = in.hasArray() ? in.array() : readDirect(in); + try { + tlsPskServerProtocol.offerInput(bytesRead); + } catch (TlsFatalAlert tlsFatalAlert) { + writeOutputIfAvailable(ctx); + return; + } + writeOutputIfAvailable(ctx); + final int appDataAvailable = tlsPskServerProtocol.getAvailableInputBytes(); + if (appDataAvailable > 0) { + byte[] appData = new byte[appDataAvailable]; + tlsPskServerProtocol.readInput(appData, 0, appDataAvailable); + out.add(Unpooled.wrappedBuffer(appData)); + } + } + + private void writeOutputIfAvailable(ChannelHandlerContext ctx) { + final int availableOutputBytes = tlsPskServerProtocol.getAvailableOutputBytes(); + // output is available immediately (handshake not complete), pipe that back to the client right away + if (availableOutputBytes != 0) { + byte[] outputBytes = new byte[availableOutputBytes]; + tlsPskServerProtocol.readOutput(outputBytes, 0, availableOutputBytes); + ctx.writeAndFlush(Unpooled.wrappedBuffer(outputBytes)) + .addListener(ChannelFutureListener.FIRE_EXCEPTION_ON_FAILURE); + } + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + tlsPskServer = + new ZuulPskServer(new JcaTlsCryptoProvider().create(new SecureRandom()), registry, externalTlsPskProvider, ctx); + tlsPskServerProtocol = new TlsPskServerProtocol(); + tlsPskServerProtocol.accept(tlsPskServer); + super.channelRegistered(ctx); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + super.channelRead(ctx, msg); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + } + + private static byte[] readDirect(ByteBuf byteBufMsg) { + int length = byteBufMsg.readableBytes(); + byte[] dest = new byte[length]; + byteBufMsg.readSlice(length).getBytes(0, dest); + return dest; + } + + /** + * Returns the name of the current application-level protocol. + * Returns: + * the protocol name or null if application-level protocol has not been negotiated + */ + public String getApplicationProtocol() { + return tlsPskServer!=null ? tlsPskServer.getApplicationProtocol() : null; + } + + public SSLSession getSession() { + return tlsPskServerProtocol!=null ? tlsPskServerProtocol.getSSLSession() : null; + } + + static class ZuulPskServer extends AbstractTlsServer { + + private static final Logger LOGGER = LoggerFactory.getLogger(ZuulPskServer.class); + + private static class PSKTimings { + private final Timer handshakeCompleteTimer; + + private Long handshakeStartTime; + + PSKTimings(Registry registry) { + handshakeCompleteTimer = registry.timer("zuul.psk.handshake.complete.time"); + } + + public void recordHandshakeStarting() { + handshakeStartTime = System.nanoTime(); + } + + public void recordHandshakeComplete() { + handshakeCompleteTimer.record(System.nanoTime() - handshakeStartTime, TimeUnit.NANOSECONDS); + } + } + + private final PSKTimings pskTimings; + + private final ExternalTlsPskProvider externalTlsPskProvider; + + private final ChannelHandlerContext ctx; + + + public ZuulPskServer( + TlsCrypto crypto, + Registry registry, + ExternalTlsPskProvider externalTlsPskProvider, ChannelHandlerContext ctx) { + super(crypto); + this.pskTimings = new PSKTimings(registry); + this.externalTlsPskProvider = externalTlsPskProvider; + this.ctx = ctx; + } + + @Override + public TlsCredentials getCredentials() { + return null; + } + + @Override + protected Vector getProtocolNames() { + Vector protocolNames = new Vector(); + protocolNames.addElement(ProtocolName.HTTP_1_1); + protocolNames.addElement(ProtocolName.HTTP_2_TLS); + return protocolNames; + } + + @Override + public void notifyHandshakeBeginning() throws IOException { + // TODO: sunnys - handshake timeouts + super.notifyHandshakeBeginning(); + pskTimings.recordHandshakeStarting(); + } + + @Override + public void notifyHandshakeComplete() throws IOException { + super.notifyHandshakeComplete(); + pskTimings.recordHandshakeComplete(); + ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); + } + + @Override + protected ProtocolVersion[] getSupportedVersions() { + return ProtocolVersion.TLSv13.only(); + } + + @Override + protected int[] getSupportedCipherSuites() { + return TlsUtils.getSupportedCipherSuites( + getCrypto(), + SUPPORTED_TLS_PSK_CIPHER_SUITE_MAP.keySet().stream() + .mapToInt(Number::intValue) + .toArray()); + } + + @Override + public ProtocolVersion getServerVersion() throws IOException { + return super.getServerVersion(); + } + + @Override + public TlsPSKExternal getExternalPSK(Vector clientPskIdentities) { + return externalTlsPskProvider.provide(clientPskIdentities); + } + + @Override + public void notifyAlertRaised(short alertLevel, short alertDescription, String message, Throwable cause) { + super.notifyAlertRaised(alertLevel, alertDescription, message, cause); + Consumer loggerFunc = (alertLevel == AlertLevel.fatal) ? LOGGER::error : LOGGER::debug; + loggerFunc.accept("TLS/PSK server raised alert: " + AlertLevel.getText(alertLevel) + ", " + + AlertDescription.getText(alertDescription)); + if (message != null) { + loggerFunc.accept("> " + message); + } + if (cause != null) { + LOGGER.error("TLS/PSK alert stacktrace", cause); + } + } + + @Override + public void notifyAlertReceived(short alertLevel, short alertDescription) { + Consumer loggerFunc = (alertLevel == AlertLevel.fatal) ? LOGGER::error : LOGGER::debug; + loggerFunc.accept("TLS 1.3 PSK server received alert: " + AlertLevel.getText(alertLevel) + ", " + + AlertDescription.getText(alertDescription)); + } + + @Override + public void processClientExtensions(Hashtable clientExtensions) throws IOException { + if (context.getSecurityParametersHandshake().getClientRandom() == null) { + throw new TlsFatalAlert(AlertDescription.internal_error); + } + super.processClientExtensions(clientExtensions); + } + + @Override + public Hashtable getServerExtensions() throws IOException { + if (context.getSecurityParametersHandshake().getServerRandom() == null) { + throw new TlsFatalAlert(AlertDescription.internal_error); + } + return super.getServerExtensions(); + } + + @Override + public void getServerExtensionsForConnection(Hashtable serverExtensions) throws IOException { + if (context.getSecurityParametersHandshake().getServerRandom() == null) { + throw new TlsFatalAlert(AlertDescription.internal_error); + } + super.getServerExtensionsForConnection(serverExtensions); + } + + public String getApplicationProtocol() { + ProtocolName protocolName = + context.getSecurityParametersConnection().getApplicationProtocol(); + if (protocolName!=null) { + return protocolName.getUtf8Decoding(); + } + return null; + } + } + + static class TlsPskServerProtocol extends TlsServerProtocol { + + public SSLSession getSSLSession() { + return new SSLSession() { + @Override + public byte[] getId() { + return tlsSession.getSessionID(); + } + + @Override + public SSLSessionContext getSessionContext() { + return null; + } + + @Override + public long getCreationTime() { + return 0; + } + + @Override + public long getLastAccessedTime() { + return 0; + } + + @Override + public void invalidate() {} + + @Override + public boolean isValid() { + return !isClosed(); + } + + @Override + public void putValue(String name, Object value) {} + + @Override + public Object getValue(String name) { + return null; + } + + @Override + public void removeValue(String name) {} + + @Override + public String[] getValueNames() { + return new String[0]; + } + + @Override + public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException { + return new Certificate[0]; + } + + @Override + public Certificate[] getLocalCertificates() { + return new Certificate[0]; + } + + @Override + public X509Certificate[] getPeerCertificateChain() throws SSLPeerUnverifiedException { + return new X509Certificate[0]; + } + + @Override + public Principal getPeerPrincipal() throws SSLPeerUnverifiedException { + return null; + } + + @Override + public Principal getLocalPrincipal() { + return null; + } + + @Override + public String getCipherSuite() { + return SUPPORTED_TLS_PSK_CIPHER_SUITE_MAP.get( + getContext().getSecurityParameters().getCipherSuite()); + } + + @Override + public String getProtocol() { + return getContext().getServerVersion().getName(); + } + + @Override + public String getPeerHost() { + return null; + } + + @Override + public int getPeerPort() { + return 0; + } + + @Override + public int getPacketBufferSize() { + return 0; + } + + @Override + public int getApplicationBufferSize() { + return 0; + } + }; + } + } +} \ No newline at end of file diff --git a/zuul-core/src/main/java/com/netflix/zuul/netty/server/ssl/SslHandshakeInfoHandler.java b/zuul-core/src/main/java/com/netflix/zuul/netty/server/ssl/SslHandshakeInfoHandler.java index 79a9154434..c3ec4230c5 100644 --- a/zuul-core/src/main/java/com/netflix/zuul/netty/server/ssl/SslHandshakeInfoHandler.java +++ b/zuul-core/src/main/java/com/netflix/zuul/netty/server/ssl/SslHandshakeInfoHandler.java @@ -23,6 +23,7 @@ import com.netflix.spectator.api.NoopRegistry; import com.netflix.spectator.api.Registry; import com.netflix.zuul.netty.ChannelUtils; +import com.netflix.zuul.netty.server.psk.TlsPskHandler; import com.netflix.zuul.passport.CurrentPassport; import com.netflix.zuul.passport.PassportState; import io.netty.channel.ChannelHandlerContext; @@ -33,15 +34,16 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslHandshakeCompletionEvent; import io.netty.util.AttributeKey; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; import java.nio.channels.ClosedChannelException; import java.security.cert.Certificate; import java.security.cert.X509Certificate; import java.util.regex.Matcher; import java.util.regex.Pattern; -import javax.net.ssl.SSLException; -import javax.net.ssl.SSLSession; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * Stores info about the client and server's SSL certificates in the context, after a successful handshake. @@ -81,10 +83,13 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc CurrentPassport.fromChannel(ctx.channel()).add(PassportState.SERVER_CH_SSL_HANDSHAKE_COMPLETE); - SslHandler sslhandler = ctx.channel().pipeline().get(SslHandler.class); - SSLSession session = sslhandler.engine().getSession(); + SSLSession session = getSSLSession(ctx); + if (session == null) { + logger.warn("Error getting the SSL handshake info. SSLSession is null"); + return; + } - ClientAuth clientAuth = whichClientAuthEnum(sslhandler); + ClientAuth clientAuth = whichClientAuthEnum(ctx); Certificate serverCert = null; X509Certificate peerCert = null; @@ -184,7 +189,24 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc super.userEventTriggered(ctx, evt); } - private ClientAuth whichClientAuthEnum(SslHandler sslhandler) { + private SSLSession getSSLSession(ChannelHandlerContext ctx) { + SslHandler sslhandler = ctx.channel().pipeline().get(SslHandler.class); + if (sslhandler != null) { + return sslhandler.engine().getSession(); + } + TlsPskHandler tlsPskHandler = ctx.channel().pipeline().get(TlsPskHandler.class); + if (tlsPskHandler != null) { + return tlsPskHandler.getSession(); + } + return null; + } + + private ClientAuth whichClientAuthEnum(ChannelHandlerContext ctx) { + SslHandler sslhandler = ctx.channel().pipeline().get(SslHandler.class); + if (sslhandler == null) { + return ClientAuth.NONE; + } + ClientAuth clientAuth; if (sslhandler.engine().getNeedClientAuth()) { clientAuth = ClientAuth.REQUIRE;