Skip to content

Commit

Permalink
Breaking apart the PSK creation to an interface
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnysingh85 committed May 21, 2024
1 parent 29124da commit e906fb9
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 7 deletions.
3 changes: 2 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ subprojects {
mockito: 'org.mockito:mockito-core:5.+',
slf4j: "org.slf4j:slf4j-api:1.7.36",
truth: 'com.google.truth:truth:1.1.5',
awaitility: 'org.awaitility:awaitility:4.2.0'
awaitility: 'org.awaitility:awaitility:4.2.0',
lombok: 'org.projectlombok:lombok:1.18.30'
]
}

Expand Down
3 changes: 3 additions & 0 deletions zuul-core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ apply plugin: "java-library"

dependencies {

compileOnly libraries.lombok
annotationProcessor(libraries.lombok)

implementation libraries.guava
// TODO(carl-mastrangelo): this can be implementation; remove Logger from public api points.
api libraries.slf4j
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@

package com.netflix.zuul.netty.server.psk;

import org.bouncycastle.tls.TlsPSKExternal;

import java.util.Vector;

public interface ExternalTlsPskProvider {
TlsPSKExternal provide(Vector clientPskIdentities);
byte[] provide(byte[] clientPskIdentity, byte[] clientRandom) throws PskCreationFailureException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.netflix.zuul.netty.server.psk;

public class PskCreationFailureException extends Exception {

public enum TlsAlertMessage {
/**
* The server does not recognize the (client) PSK identity
*/
unknown_psk_identity,
/**
* The (client) PSK identity existed but the key was incorrect
*/
decrypt_error,
}

private final TlsAlertMessage tlsAlertMessage;

public PskCreationFailureException(TlsAlertMessage tlsAlertMessage, String message) {
super(message);
this.tlsAlertMessage = tlsAlertMessage;
}

public PskCreationFailureException(TlsAlertMessage tlsAlertMessage, String message, Throwable cause) {
super(message, cause);
this.tlsAlertMessage = tlsAlertMessage;
}

public TlsAlertMessage getTlsAlertMessage() {
return tlsAlertMessage;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,23 @@
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.util.ReferenceCountUtil;
import lombok.SneakyThrows;
import org.bouncycastle.tls.AbstractTlsServer;
import org.bouncycastle.tls.AlertDescription;
import org.bouncycastle.tls.AlertLevel;
import org.bouncycastle.tls.BasicTlsPSKExternal;
import org.bouncycastle.tls.CipherSuite;
import org.bouncycastle.tls.PRFAlgorithm;
import org.bouncycastle.tls.ProtocolName;
import org.bouncycastle.tls.ProtocolVersion;
import org.bouncycastle.tls.PskIdentity;
import org.bouncycastle.tls.TlsCredentials;
import org.bouncycastle.tls.TlsFatalAlert;
import org.bouncycastle.tls.TlsPSKExternal;
import org.bouncycastle.tls.TlsServerProtocol;
import org.bouncycastle.tls.TlsUtils;
import org.bouncycastle.tls.crypto.TlsCrypto;
import org.bouncycastle.tls.crypto.TlsSecret;
import org.bouncycastle.tls.crypto.impl.jcajce.JcaTlsCryptoProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -291,8 +296,21 @@ public ProtocolVersion getServerVersion() throws IOException {
}

@Override
@SneakyThrows // TODO: Ask BC folks to see if getExternalPSK can throw a checked exception
public TlsPSKExternal getExternalPSK(Vector clientPskIdentities) {
return externalTlsPskProvider.provide(clientPskIdentities);
byte[] clientPskIdentity = ((PskIdentity)clientPskIdentities.get(0)).getIdentity();
byte[] psk;
try{
psk = externalTlsPskProvider.provide(clientPskIdentity, this.context.getSecurityParametersHandshake().getClientRandom());
}catch (PskCreationFailureException e) {
throw switch (e.getTlsAlertMessage()) {
case unknown_psk_identity -> new TlsFatalAlert(AlertDescription.unknown_psk_identity, "Unknown or null client PSk identity");
case decrypt_error -> new TlsFatalAlert(AlertDescription.decrypt_error, "Invalid or expired client PSk identity");
};
}
TlsSecret pskTlsSecret = getCrypto().createSecret(psk);
int prfAlgorithm = getPRFAlgorithm13(getSelectedCipherSuite());
return new BasicTlsPSKExternal(clientPskIdentity, pskTlsSecret, prfAlgorithm);
}

@Override
Expand Down Expand Up @@ -348,6 +366,18 @@ public String getApplicationProtocol() {
}
return null;
}

private static int getPRFAlgorithm13(int cipherSuite) {
return switch (cipherSuite) {
case CipherSuite.TLS_AES_128_CCM_SHA256,
CipherSuite.TLS_AES_128_CCM_8_SHA256,
CipherSuite.TLS_AES_128_GCM_SHA256,
CipherSuite.TLS_CHACHA20_POLY1305_SHA256 -> PRFAlgorithm.tls13_hkdf_sha256;
case CipherSuite.TLS_AES_256_GCM_SHA384 -> PRFAlgorithm.tls13_hkdf_sha384;
case CipherSuite.TLS_SM4_CCM_SM3, CipherSuite.TLS_SM4_GCM_SM3 -> PRFAlgorithm.tls13_hkdf_sm3;
default -> -1;
};
}
}

static class TlsPskServerProtocol extends TlsServerProtocol {
Expand Down

0 comments on commit e906fb9

Please sign in to comment.