diff --git a/README.md b/README.md index cd1e0bf..c120e88 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,9 @@ AbsaOSS Common Login service using JWT Public key signatures To interact with the service, most notable endpoints are - `/token/generate` to generate access & refresh tokens - `/token/refresh` to obtain a new access token with a still-valid refresh token - - `/token/public-key` to obtain public key to verify tokens including their validity window + - `/token/public-key` to obtain the currently signing public key to verify tokens including their validity window + - `/token/public-keys` to obtain all available public keys including the current and previously rotated keys. + - `/token/public-key-jwks` gives same data as `/token/public-keys` but in the form of a JSON Web Key Set. Please, refer to the [API documentation](#api-documentation) below for details of the endpoints. @@ -201,14 +203,28 @@ loginsvc: access-exp-time: 15min refresh-exp-time: 9h key-rotation-time: 9h + key-lay-over-time: 15min + key-phase-out-time: 30min alg-name: "RS256" ``` There are a few important configuration values to be provided: - `access-exp-time` which indicates how long an access token is valid for, - `refresh-exp-time` which indicates how long a refresh token is valid for, - Optional property: `key-rotation-time` which indicates how often Key pairs are rotated. Rotation will be disabled if missing. +- Optional property: `key-lay-over-time` which indicates a delay after rotation before using the newly created key for signing. Lay-over will be disabled if missing. +- Optional property: `key-phase-out-time` which indicates the time to phase out the older key. Timer is scheduled after `key-lay-over-time` if enabled. Phase-out will be disabled if missing. - `alg-name` which indicates which algorithm is used to encode your keys. +Using the above values, the optional properties will give the following effect after the 1st rotation at 9 hours: +``` +t=0: keys rotation happens +t=0-14m: layover period: old key from before rotation is still used for signing. Both public keys available from public-keys endpoint. +t=15-44m: layover is over: new key from after rotation is used for signing. Both public keys available from public-keys endpoint. +t=45m+: phase-out happens: new key from after rotation is used for signing. Old Key is no longer available from public-keys endpoint. +``` +These properties cannot be enabled if rotation is not enabled. The combined values of these properties cannot be higher than the rotation time. + + To setup for AWS Secrets Manager, your config should look like so: ``` loginsvc: @@ -222,6 +238,8 @@ loginsvc: access-exp-time: 15min refresh-exp-time: 9h poll-time: 30min + key-lay-over-time: 15min + key-phase-out-time: 30min alg-name: "RS256" ``` Your AWS Secret must have at least 2 fields which correspond to the above properties: @@ -236,7 +254,17 @@ There are a few important configuration values to be provided: - `access-exp-time` which indicates how long an access token is valid for, - `refresh-exp-time` which indicates how long a refresh token is valid for, - Optional property:`poll-time` which indicates how often key pairs (`private-key-field-name` and `public-key-field-name`) are polled and fetched from AWS Secrets Manager. Polling will be disabled if missing. +- Optional property: `key-lay-over-time` which indicates a delay after rotation before using the newly created key for signing. Lay-over will be disabled if missing. +- Optional property: `key-phase-out-time` which indicates the time to phase out the older key. Timer is scheduled after `key-lay-over-time` if enabled. Phase-out will be disabled if missing. - `alg-name` which indicates which algorithm is used to encode your keys. + Using the above values, the optional properties will give the following effect after the 1st rotation at 9 hours: +``` +t=0: keys rotation happens +t=0-14m: layover period: old key from before rotation is still used for signing. Both public keys available from public-keys endpoint. +t=15-44m: layover is over: new key from after rotation is used for signing. Both public keys available from public-keys endpoint. +t=45m+: phase-out happens: new key from after rotation is used for signing. Old Key is no longer available from public-keys endpoint. +``` +These properties cannot be enabled if polling is not enabled. Please note that only one configuration option (`loginsvc.rest.jwt.{aws-secrets-manager|generate-in-memory}`) can be used at a time. diff --git a/api/src/main/resources/example.application.yaml b/api/src/main/resources/example.application.yaml index a8bee3b..f6aab73 100644 --- a/api/src/main/resources/example.application.yaml +++ b/api/src/main/resources/example.application.yaml @@ -7,7 +7,8 @@ loginsvc: access-exp-time: 15min refresh-exp-time: 9h key-rotation-time: 9h - key-phase-out-time: 30min + key-lay-over-time: 15min + key-phase-out-time: 15min alg-name: "RS256" #Instead of generating the key in memory #The Below Config allows for the application to fetch keys from AWS Secrets Manager. @@ -19,7 +20,8 @@ loginsvc: #access-exp-time: 15min #refresh-exp-time: 9h #poll-time: 5min - #key-phase-out-time: 30min + #key-lay-over-time: 15min + #key-phase-out-time: 15min #alg-name: "RS256" config: # Generates git.properties file for use on info endpoint. diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfig.scala index 97a7c4d..b33ec41 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfig.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfig.scala @@ -19,13 +19,13 @@ package za.co.absa.loginsvc.rest.config.jwt import org.slf4j.LoggerFactory import za.co.absa.loginsvc.rest.config.validation.{ConfigValidationException, ConfigValidationResult} import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess} -import za.co.absa.loginsvc.utils.AwsSecretsUtils +import za.co.absa.loginsvc.utils.{AwsSecretsUtils, SecretUtils} import java.security.{KeyFactory, KeyPair} import java.security.spec.{PKCS8EncodedKeySpec, X509EncodedKeySpec} import java.time.Instant import java.util.Base64 -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration.{Duration, FiniteDuration} case class AwsSecretsManagerKeyConfig( secretName: String, @@ -36,18 +36,58 @@ case class AwsSecretsManagerKeyConfig( accessExpTime: FiniteDuration, refreshExpTime: FiniteDuration, pollTime: Option[FiniteDuration], + keyLayOverTime: Option[FiniteDuration], keyPhaseOutTime: Option[FiniteDuration] ) extends KeyConfig { private val logger = LoggerFactory.getLogger(classOf[AwsSecretsManagerKeyConfig]) override def keyRotationTime : Option[FiniteDuration] = pollTime - override def keyPair(): (KeyPair, Option[KeyPair]) = { + override def keyPair(): (KeyPair, Option[KeyPair]) = fetchKeySetsFromCloud() + + override def throwErrors(): Unit = this.validate().throwOnErrors() + + override def validate(): ConfigValidationResult = { + + val awsSecretsResults = Seq( + Option(secretName) + .map(_ => ConfigValidationSuccess) + .getOrElse(ConfigValidationError(ConfigValidationException("secretName is empty"))), + + Option(region) + .map(_ => ConfigValidationSuccess) + .getOrElse(ConfigValidationError(ConfigValidationException("region is empty"))), + + Option(privateKeyFieldName) + .map(_ => ConfigValidationSuccess) + .getOrElse(ConfigValidationError(ConfigValidationException("privateKeyFieldName is empty"))), + + Option(publicKeyFieldName) + .map(_ => ConfigValidationSuccess) + .getOrElse(ConfigValidationError(ConfigValidationException("publicKeyFieldName is empty"))), + ) + + val awsSecretsResultsMerge = awsSecretsResults.foldLeft[ConfigValidationResult](ConfigValidationSuccess)(ConfigValidationResult.merge) + + super.validate().merge(awsSecretsResultsMerge) + } + + /** + * Fetches the keypair used for generating Java Web Tokens from Cloud. + * Fetches both the current as well as previously rotated keys if available. + * + * @param secretsUtils The methods used to fetch the keys. + * Mainly used for testing and can be left empty to use the default value in standard use. + * @return A tuple of the most current KeyPair as well as an option of the previously rotated keypair if available. + * The order and availability of the keys are dependant on key-lay-over and key-phase-out if enabled. + */ + private[jwt] def fetchKeySetsFromCloud(secretsUtils: SecretUtils = AwsSecretsUtils): (KeyPair, Option[KeyPair]) = { try { - val currentSecretsOption = AwsSecretsUtils.fetchSecret( + val currentSecretsOption = secretsUtils.fetchSecret( secretName, region, - Array(privateKeyFieldName, publicKeyFieldName) + Array(privateKeyFieldName, publicKeyFieldName), + None ) if(currentSecretsOption.isEmpty) @@ -58,19 +98,20 @@ case class AwsSecretsManagerKeyConfig( logger.info("AWSCURRENT Key Data successfully retrieved and parsed from AWS Secrets Manager") val previousSecretsOption = - AwsSecretsUtils.fetchSecret( - secretName, - region, - Array(privateKeyFieldName, publicKeyFieldName), - Some("AWSPREVIOUS") - ) + secretsUtils.fetchSecret( + secretName, + region, + Array(privateKeyFieldName, publicKeyFieldName), + Some("AWSPREVIOUS") + ) val previousKeyPair = previousSecretsOption.flatMap { previousSecrets => try { val keys = createKeyPair(previousSecrets.secretValue) logger.info("AWSPREVIOUS Key Data successfully retrieved and parsed from AWS Secrets Manager") - val exp = keyPhaseOutTime.exists(isExpired(currentSecrets.createTime, _)) - if(exp) { None } + val keyPhaseOutActive = keyPhaseOutTime.exists(kpot => + isExpired(currentSecrets.createTime, kpot + keyLayOverTime.getOrElse(Duration.Zero))) + if(keyPhaseOutActive) { None } else { Some(keys) } } catch { case e: Throwable => @@ -79,7 +120,15 @@ case class AwsSecretsManagerKeyConfig( } } - (currentKeyPair, previousKeyPair) + previousKeyPair.fold {(currentKeyPair, previousKeyPair)} { pk => + val keyLayOverActive = keyLayOverTime.exists(!isExpired(currentSecrets.createTime, _)) + if (!keyLayOverActive) { + (currentKeyPair, previousKeyPair) + } + else { + (pk, Some(currentKeyPair)) + } + } } catch { case e: Throwable => logger.error(s"Error occurred retrieving and decoding keys from AWS Secrets Manager", e) @@ -87,33 +136,6 @@ case class AwsSecretsManagerKeyConfig( } } - override def throwErrors(): Unit = this.validate().throwOnErrors() - - override def validate(): ConfigValidationResult = { - - val awsSecretsResults = Seq( - Option(secretName) - .map(_ => ConfigValidationSuccess) - .getOrElse(ConfigValidationError(ConfigValidationException("secretName is empty"))), - - Option(region) - .map(_ => ConfigValidationSuccess) - .getOrElse(ConfigValidationError(ConfigValidationException("region is empty"))), - - Option(privateKeyFieldName) - .map(_ => ConfigValidationSuccess) - .getOrElse(ConfigValidationError(ConfigValidationException("privateKeyFieldName is empty"))), - - Option(publicKeyFieldName) - .map(_ => ConfigValidationSuccess) - .getOrElse(ConfigValidationError(ConfigValidationException("publicKeyFieldName is empty"))), - ) - - val awsSecretsResultsMerge = awsSecretsResults.foldLeft[ConfigValidationResult](ConfigValidationSuccess)(ConfigValidationResult.merge) - - super.validate().merge(awsSecretsResultsMerge) - } - private def createKeyPair(secretKeys: Map[String, String]): KeyPair = { val publicKeySpec: X509EncodedKeySpec = new X509EncodedKeySpec( diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfig.scala index 0239a59..06571b7 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfig.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfig.scala @@ -23,13 +23,14 @@ import za.co.absa.loginsvc.rest.config.validation.{ConfigValidationException, Co import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess} import java.security.KeyPair -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration.{Duration, FiniteDuration} case class InMemoryKeyConfig( algName: String, accessExpTime: FiniteDuration, refreshExpTime: FiniteDuration, keyRotationTime: Option[FiniteDuration], + keyLayOverTime: Option[FiniteDuration], keyPhaseOutTime: Option[FiniteDuration] ) extends KeyConfig { @@ -45,12 +46,13 @@ case class InMemoryKeyConfig( } override def validate(): ConfigValidationResult = { - val keyPhaseOutTimeResult = if(keyPhaseOutTime.nonEmpty && keyRotationTime.nonEmpty - && keyPhaseOutTime.get > keyRotationTime.get) { - ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be lower than keyRotationTime!")) + val optionalKeyTimeResult = if(keyRotationTime.nonEmpty + && (keyPhaseOutTime.getOrElse(Duration.Zero) + keyLayOverTime.getOrElse(Duration.Zero)) > keyRotationTime.get) { + ConfigValidationError(ConfigValidationException( + s"keyLayOverTime + keyPhaseOutTime must be lower than keyRotationTime!")) } else ConfigValidationSuccess - super.validate().merge(keyPhaseOutTimeResult) + super.validate().merge(optionalKeyTimeResult) } override def throwErrors(): Unit = this.validate().throwOnErrors() diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/KeyConfig.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/KeyConfig.scala index ee96104..da7abb1 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/KeyConfig.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/config/jwt/KeyConfig.scala @@ -30,6 +30,7 @@ trait KeyConfig extends ConfigValidatable { def accessExpTime: FiniteDuration def refreshExpTime: FiniteDuration def keyRotationTime: Option[FiniteDuration] + def keyLayOverTime: Option[FiniteDuration] def keyPhaseOutTime: Option[FiniteDuration] def keyPair(): (KeyPair, Option[KeyPair]) def throwErrors(): Unit @@ -79,6 +80,14 @@ trait KeyConfig extends ConfigValidatable { ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!")) } else ConfigValidationSuccess + val keyLayoverTimeResult = if (keyLayOverTime.nonEmpty && keyLayOverTime.get < KeyConfig.minKeyLayOverTime) { + ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be at least ${KeyConfig.minKeyLayOverTime}")) + } else ConfigValidationSuccess + + val keyLayOverWithRotationResult = if (keyLayOverTime.nonEmpty && keyRotationTime.isEmpty) { + ConfigValidationError(ConfigValidationException(s"keyLayOverTime can only be enable if keyRotationTime is enable!")) + } else ConfigValidationSuccess + if (keyRotationTime.isEmpty) { logger.warn("keyRotationTime is not set in config, key-pair will not be rotated!") } @@ -93,6 +102,8 @@ trait KeyConfig extends ConfigValidatable { .merge(keyRotationTimeResult) .merge(keyPhaseOutTimeResult) .merge(keyPhaseOutWithRotationResult) + .merge(keyLayoverTimeResult) + .merge(keyLayOverWithRotationResult) } } @@ -101,4 +112,5 @@ object KeyConfig { val minRefreshExpTime: FiniteDuration = 10.milliseconds val minKeyRotationTime: FiniteDuration = 10.milliseconds val minKeyPhaseOutTime: FiniteDuration = 10.milliseconds + val minKeyLayOverTime: FiniteDuration = 10.milliseconds } diff --git a/api/src/main/scala/za/co/absa/loginsvc/rest/service/jwt/JWTService.scala b/api/src/main/scala/za/co/absa/loginsvc/rest/service/jwt/JWTService.scala index 26b8a72..d2cd741 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/rest/service/jwt/JWTService.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/rest/service/jwt/JWTService.scala @@ -26,7 +26,7 @@ import za.co.absa.loginsvc.model.User import za.co.absa.loginsvc.rest.config.jwt.InMemoryKeyConfig import za.co.absa.loginsvc.rest.config.provider.JwtConfigProvider import za.co.absa.loginsvc.rest.model.{AccessToken, RefreshToken, Token} -import za.co.absa.loginsvc.rest.service.jwt.JWTService.extractUserFrom +import za.co.absa.loginsvc.rest.service.jwt.JWTService.{extractUserFrom, parseWithKeys} import za.co.absa.loginsvc.rest.service.search.UserSearchService import java.security.interfaces.RSAPublicKey @@ -36,13 +36,13 @@ import java.util.Date import java.util.concurrent.{ScheduledThreadPoolExecutor, ThreadFactory, TimeUnit} import scala.collection.JavaConverters._ import scala.compat.java8.DurationConverters._ -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration.{Duration, FiniteDuration} @Service class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchService: UserSearchService) { private val logger = LoggerFactory.getLogger(classOf[JWTService]) - private val scheduler = new ScheduledThreadPoolExecutor(2, new ThreadFactory { + private val scheduler = new ScheduledThreadPoolExecutor(3, new ThreadFactory { override def newThread(r: Runnable): Thread = { val t = new Thread(r) t.setDaemon(true) @@ -107,21 +107,29 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe } def refreshTokens(accessToken: AccessToken, refreshToken: RefreshToken): (AccessToken, RefreshToken) = { - val oldAccessJws: Jws[Claims] = Jwts.parserBuilder() - .require("type", Token.TokenType.Access.toString) - .setSigningKey(primaryKeyPair.getPublic) - .setClock(() => Date.from(Instant.now().minus(jwtConfig.refreshExpTime.toJava))) // allowing expired access token - up to refresh token validity window - .build() - .parseClaimsJws(accessToken.token) // checks requirements: type=access, signature, custom validity window - - val userFromOldAccessToken: User = extractUserFrom(oldAccessJws.getBody) - - Jwts.parserBuilder() - .require("type", Token.TokenType.Refresh.toString) - .requireSubject(userFromOldAccessToken.name) - .setSigningKey(primaryKeyPair.getPublic) - .build() - .parseClaimsJws(refreshToken.token) // checks username, validity, and signature. + + val keyList: List[PublicKey] = List(primaryKeyPair.getPublic) ++ optionalKeyPair.map(_.getPublic).toList + + val oldAccessJws: Option[Jws[Claims]] = parseWithKeys( + accessToken, + keyList, + Token.TokenType.Access.toString, + Some(jwtConfig.refreshExpTime) + ) // checks requirements: type=access, signature, custom validity window + + if(oldAccessJws.isEmpty) + throw new JwtException("Tokens are incompatible with current keys. Please request new Tokens!") + + val userFromOldAccessToken: User = extractUserFrom(oldAccessJws.get.getBody) + + val refreshClaims = parseWithKeys( + refreshToken, + keyList, + Token.TokenType.Refresh.toString + ) + + if(refreshClaims.isEmpty) + throw new JwtException("Tokens are incompatible with current keys. Please request new Tokens!") val userUpdatedDetails = { try { @@ -190,15 +198,30 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe val scheduledFuture = scheduler.scheduleAtFixedRate(() => { logger.info("Attempting to Refresh for new Keys") try { - val (newPrimaryKeyPair, newOptionalKeyPair) = jwtConfig.keyPair() + var (newPrimaryKeyPair, newOptionalKeyPair) = jwtConfig.keyPair() logger.info("Keys have been Refreshed") + + jwtConfig.keyLayOverTime.foreach { kl => { + jwtConfig match { + case _: InMemoryKeyConfig => + newOptionalKeyPair.foreach { tok => + scheduleKeyLayOver(kl) + val temp = tok + newOptionalKeyPair = Some(newPrimaryKeyPair) + newPrimaryKeyPair = temp + } + case _ => + } + }} + jwtConfig.keyPhaseOutTime.foreach { kp => { jwtConfig match { case _: InMemoryKeyConfig => - scheduleKeyPhaseOut(kp) + scheduleKeyPhaseOut(kp + jwtConfig.keyLayOverTime.getOrElse(Duration.Zero)) case _ => } }} + primaryKeyPair = newPrimaryKeyPair optionalKeyPair = newOptionalKeyPair } @@ -211,7 +234,6 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe refreshTime.toMillis, TimeUnit.MILLISECONDS ) - Runtime.getRuntime.addShutdownHook(new Thread(() => { scheduledFuture.cancel(false) this.close() @@ -225,7 +247,23 @@ class JWTService @Autowired()(jwtConfigProvider: JwtConfigProvider, authSearchSe optionalKeyPair = None } }, phaseOutTime.toMillis, TimeUnit.MILLISECONDS) + Runtime.getRuntime.addShutdownHook(new Thread(() => { + scheduledFuture.cancel(false) + this.close() + })) + } + private def scheduleKeyLayOver(layOverTime: FiniteDuration): Unit = { + val scheduledFuture = scheduler.schedule(new Runnable { + override def run(): Unit = { + logger.info("Switching Signing key") + optionalKeyPair.foreach { okp => + val temp = okp + optionalKeyPair = Some(primaryKeyPair) + primaryKeyPair = temp + } + } + }, layOverTime.toMillis, TimeUnit.MILLISECONDS) Runtime.getRuntime.addShutdownHook(new Thread(() => { scheduledFuture.cancel(false) this.close() @@ -254,4 +292,29 @@ object JWTService { User(name, groups, optionalAttributes) } + + def parseWithKeys( + token: Token, + keys: List[PublicKey], + accessType: String, + clock: Option[FiniteDuration] = None + ): Option[Jws[Claims]] = { + keys.flatMap { key => + try { + val builder = Jwts.parserBuilder() + .require("type", accessType) + .setSigningKey(key) + + clock.foreach(time => builder.setClock(() => Date.from(Instant.now().minus(time.toJava)))) + + Some(builder.build().parseClaimsJws(token.token)) + } catch { + case e: MalformedJwtException => + throw e + case e: ExpiredJwtException => + throw e + case _: JwtException => None + } + }.headOption + } } diff --git a/api/src/main/scala/za/co/absa/loginsvc/utils/AwsSecretsUtils.scala b/api/src/main/scala/za/co/absa/loginsvc/utils/AwsSecretsUtils.scala index 939ed9a..7d4497f 100644 --- a/api/src/main/scala/za/co/absa/loginsvc/utils/AwsSecretsUtils.scala +++ b/api/src/main/scala/za/co/absa/loginsvc/utils/AwsSecretsUtils.scala @@ -24,7 +24,7 @@ import software.amazon.awssdk.services.secretsmanager.SecretsManagerClient import software.amazon.awssdk.services.secretsmanager.model.{GetSecretValueRequest, GetSecretValueResponse} import za.co.absa.loginsvc.rest.model.AwsSecret -object AwsSecretsUtils { +object AwsSecretsUtils extends SecretUtils { private val logger = LoggerFactory.getLogger(getClass) def fetchSecret( @@ -66,3 +66,12 @@ object AwsSecretsUtils { } } } + +trait SecretUtils { + def fetchSecret( + secretName: String, + region: String, + secretFields: Array[String], + versionStage: Option[String] = None + ): Option[AwsSecret] +} diff --git a/api/src/test/resources/application.yaml b/api/src/test/resources/application.yaml index 95b418e..f5f8c7f 100644 --- a/api/src/test/resources/application.yaml +++ b/api/src/test/resources/application.yaml @@ -5,7 +5,8 @@ loginsvc: generate-in-memory: access-exp-time: 15min refresh-exp-time: 10h - key-rotation-time: 5sec + key-rotation-time: 10sec + key-lay-over-time: 3sec key-phase-out-time: 3sec alg-name: "RS256" diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfigTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfigTest.scala index 22d1641..7ea04a4 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfigTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/AwsSecretsManagerKeyConfigTest.scala @@ -16,16 +16,24 @@ package za.co.absa.loginsvc.rest.config.jwt +import io.jsonwebtoken.SignatureAlgorithm +import io.jsonwebtoken.security.Keys +import org.mockito.ArgumentMatchers.{any, anyString, eq => eqMatch} +import org.mockito.Mockito.{mock, when} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import za.co.absa.loginsvc.rest.config.validation.ConfigValidationException import za.co.absa.loginsvc.rest.config.validation.ConfigValidationResult.{ConfigValidationError, ConfigValidationSuccess} +import za.co.absa.loginsvc.rest.model.AwsSecret +import za.co.absa.loginsvc.utils.SecretUtils +import java.time.Instant +import java.util.Base64 import scala.concurrent.duration._ class AwsSecretsManagerKeyConfigTest extends AnyFlatSpec with Matchers { - val awsSecretsManagerKeyConfig: AwsSecretsManagerKeyConfig = AwsSecretsManagerKeyConfig("Secret", + private val awsSecretsManagerKeyConfig: AwsSecretsManagerKeyConfig = AwsSecretsManagerKeyConfig("Secret", "region", "private", "public", @@ -33,7 +41,28 @@ class AwsSecretsManagerKeyConfigTest extends AnyFlatSpec with Matchers { 15.minutes, 9.minutes, Option(30.minutes), - Option(15.minutes)) + Option(15.minutes), + Option(5.minutes)) + + private val mockSecretsUtil = mock(classOf[SecretUtils]) + private val currentKeyPair = Keys.keyPairFor(SignatureAlgorithm.RS256) + private val previousKeyPair = Keys.keyPairFor(SignatureAlgorithm.RS256) + private val currentKeyPairMaps = Map( + "private" -> Base64.getEncoder.encodeToString(currentKeyPair.getPrivate.getEncoded), + "public" -> Base64.getEncoder.encodeToString(currentKeyPair.getPublic.getEncoded)) + private val previousKeyPairMaps = Map( + "private" -> Base64.getEncoder.encodeToString(previousKeyPair.getPrivate.getEncoded), + "public" -> Base64.getEncoder.encodeToString(previousKeyPair.getPublic.getEncoded)) + + private val currentSecret = Some(AwsSecret(currentKeyPairMaps, Instant.now())) + private val currentSecretAfterLayOver = Some(AwsSecret(currentKeyPairMaps, + Instant.now().minus(16.minutes.toMillis, java.time.temporal.ChronoUnit.MILLIS))) + private val currentSecretAfterPhase = Some(AwsSecret(currentKeyPairMaps, + Instant.now().minus(21.minutes.toMillis, java.time.temporal.ChronoUnit.MILLIS))) + private val previousSecret = Some(AwsSecret(previousKeyPairMaps, + Instant.now().minus(6.hours.toMillis, java.time.temporal.ChronoUnit.MILLIS))) + + behavior of "validation" "awsSecretsManagerKeyConfig" should "validate expected content" in { awsSecretsManagerKeyConfig.validate() shouldBe ConfigValidationSuccess @@ -60,17 +89,104 @@ class AwsSecretsManagerKeyConfigTest extends AnyFlatSpec with Matchers { } it should "fail on non-negative keyPhaseOutTime" in { - awsSecretsManagerKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds)).validate() shouldBe + awsSecretsManagerKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds), keyLayOverTime = None).validate() shouldBe ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be at least ${KeyConfig.minKeyPhaseOutTime}")) } it should "fail on keyPhaseOutTime being configured without keyRotationTime" in { - awsSecretsManagerKeyConfig.copy(pollTime = None).validate() shouldBe + awsSecretsManagerKeyConfig.copy(pollTime = None, keyLayOverTime = None).validate() shouldBe ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!")) } + it should "fail on non-negative keyLayOverTime" in { + awsSecretsManagerKeyConfig.copy(keyLayOverTime = Option(5.milliseconds)).validate() shouldBe + ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be at least ${KeyConfig.minKeyLayOverTime}")) + } + + it should "fail on keyLayOverTime being configured without keyRotationTime" in { + awsSecretsManagerKeyConfig.copy(pollTime = None, keyPhaseOutTime = None).validate() shouldBe + ConfigValidationError(ConfigValidationException(s"keyLayOverTime can only be enable if keyRotationTime is enable!")) + } + it should "fail on missing value" in { awsSecretsManagerKeyConfig.copy(secretName = null).validate() shouldBe ConfigValidationError(ConfigValidationException("secretName is empty")) } + + behavior of "fetchKeySetsFromCloud" + + it should "not use keyLayOver when previousKey is not available" in { + when (mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(None))) + .thenReturn(currentSecret) + when(mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(Some("AWSPREVIOUS")))) + .thenReturn(None) + + val (currentKey, opPreviousKey) = awsSecretsManagerKeyConfig.fetchKeySetsFromCloud(mockSecretsUtil) + + assert(currentKey.getPrivate == currentKeyPair.getPrivate) + assert(currentKey.getPublic == currentKeyPair.getPublic) + assert(opPreviousKey.isEmpty) + + when (mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(None))).thenReturn(currentSecretAfterLayOver) + + val (currentKeyAfterLayover, optPreviousKeyAfterLayOver) = awsSecretsManagerKeyConfig.fetchKeySetsFromCloud(mockSecretsUtil) + + assert(currentKeyAfterLayover.getPrivate == currentKeyPair.getPrivate) + assert(currentKeyAfterLayover.getPublic == currentKeyPair.getPublic) + assert(optPreviousKeyAfterLayOver.isEmpty) + } + + it should "not use keyPhaseOut when previousKey is not available" in { + when (mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(None))) + .thenReturn(currentSecretAfterPhase) + when(mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(Some("AWSPREVIOUS")))) + .thenReturn(None) + + val (currentKey, opPreviousKey) = awsSecretsManagerKeyConfig.fetchKeySetsFromCloud(mockSecretsUtil) + + assert(currentKey.getPrivate == currentKeyPair.getPrivate) + assert(currentKey.getPublic == currentKeyPair.getPublic) + assert(opPreviousKey.isEmpty) + } + + it should "use previousKey during the keyLayOver Period" in { + when (mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(None))) + .thenReturn(currentSecret) + when(mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(Some("AWSPREVIOUS")))) + .thenReturn(previousSecret) + + val (currentKey, opPreviousKey) = awsSecretsManagerKeyConfig.fetchKeySetsFromCloud(mockSecretsUtil) + + assert(currentKey.getPrivate == previousKeyPair.getPrivate) + assert(currentKey.getPublic == previousKeyPair.getPublic) + assert(opPreviousKey.get.getPrivate == currentKeyPair.getPrivate) + assert(opPreviousKey.get.getPublic == currentKeyPair.getPublic) + } + + it should "use currentKey after the keyLayOver Period" in { + when (mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(None))) + .thenReturn(currentSecretAfterLayOver) + when(mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(Some("AWSPREVIOUS")))) + .thenReturn(previousSecret) + + val (currentKey, opPreviousKey) = awsSecretsManagerKeyConfig.fetchKeySetsFromCloud(mockSecretsUtil) + + assert(currentKey.getPrivate == currentKeyPair.getPrivate) + assert(currentKey.getPublic == currentKeyPair.getPublic) + assert(opPreviousKey.get.getPrivate == previousKeyPair.getPrivate) + assert(opPreviousKey.get.getPublic == previousKeyPair.getPublic) + } + + it should "set the previousKey to None after the keyPhaseOut period" in { + when (mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(None))) + .thenReturn(currentSecretAfterPhase) + when(mockSecretsUtil.fetchSecret(anyString(),anyString(),any[Array[String]](),eqMatch(Some("AWSPREVIOUS")))) + .thenReturn(previousSecret) + + val (currentKey, opPreviousKey) = awsSecretsManagerKeyConfig.fetchKeySetsFromCloud(mockSecretsUtil) + + assert(currentKey.getPrivate == currentKeyPair.getPrivate) + assert(currentKey.getPublic == currentKeyPair.getPublic) + assert(opPreviousKey.isEmpty) + } } diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfigTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfigTest.scala index 1144961..fe29e3c 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfigTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/config/jwt/InMemoryKeyConfigTest.scala @@ -29,7 +29,8 @@ class InMemoryKeyConfigTest extends AnyFlatSpec with Matchers { 15.minutes, 2.hours, Option(30.minutes), - Option(15.minutes)) + Option(15.minutes), + Option(5.minutes)) "inMemoryKeyConfig" should "validate expected content" in { inMemoryKeyConfig.validate() shouldBe ConfigValidationSuccess @@ -51,22 +52,33 @@ class InMemoryKeyConfigTest extends AnyFlatSpec with Matchers { } it should "fail on non-negative keyRotationTime" in { - inMemoryKeyConfig.copy(keyRotationTime = Option(5.milliseconds), keyPhaseOutTime = None).validate() shouldBe + inMemoryKeyConfig.copy(keyRotationTime = Option(5.milliseconds), keyPhaseOutTime = None, keyLayOverTime = None).validate() shouldBe ConfigValidationError(ConfigValidationException(s"keyRotationTime must be at least ${KeyConfig.minKeyRotationTime}")) } it should "fail on non-negative keyPhaseOutTime" in { - inMemoryKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds)).validate() shouldBe + inMemoryKeyConfig.copy(keyPhaseOutTime = Option(5.milliseconds), keyLayOverTime = None).validate() shouldBe ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be at least ${KeyConfig.minKeyPhaseOutTime}")) } it should "fail on keyPhaseOutTime being configured without keyRotationTime" in { - inMemoryKeyConfig.copy(keyRotationTime = None).validate() shouldBe + inMemoryKeyConfig.copy(keyRotationTime = None, keyLayOverTime = None).validate() shouldBe ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime can only be enable if keyRotationTime is enable!")) } - it should "fail on keyPhaseOutTime being larger than keyRotationTime" in { + it should "fail on non-negative keyLayOverTime" in { + inMemoryKeyConfig.copy(keyLayOverTime = Option(5.milliseconds)).validate() shouldBe + ConfigValidationError(ConfigValidationException(s"keyLayOverTime must be at least ${KeyConfig.minKeyLayOverTime}")) + } + + it should "fail on keyLayOverTime being configured without keyRotationTime" in { + inMemoryKeyConfig.copy(keyRotationTime = None, keyPhaseOutTime = None).validate() shouldBe + ConfigValidationError(ConfigValidationException(s"keyLayOverTime can only be enable if keyRotationTime is enable!")) + } + + it should "fail on keyLayOverTime + keyPhaseOutTime being larger than keyRotationTime" in { inMemoryKeyConfig.copy(keyRotationTime = Option(10.minutes)).validate() shouldBe - ConfigValidationError(ConfigValidationException(s"keyPhaseOutTime must be lower than keyRotationTime!")) + ConfigValidationError(ConfigValidationException(s"keyLayOverTime + keyPhaseOutTime must be lower than keyRotationTime!")) } + } diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/config/provider/ConfigProviderTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/config/provider/ConfigProviderTest.scala index d6e2c7b..0570197 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/config/provider/ConfigProviderTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/config/provider/ConfigProviderTest.scala @@ -37,7 +37,9 @@ class ConfigProviderTest extends AnyFlatSpec with Matchers { keyConfig.algName shouldBe "RS256" keyConfig.accessExpTime shouldBe FiniteDuration(15, TimeUnit.MINUTES) keyConfig.refreshExpTime shouldBe FiniteDuration(10, TimeUnit.HOURS) - keyConfig.keyRotationTime.get shouldBe FiniteDuration(5, TimeUnit.SECONDS) + keyConfig.keyRotationTime.get shouldBe FiniteDuration(10, TimeUnit.SECONDS) + keyConfig.keyLayOverTime.get shouldBe FiniteDuration(3, TimeUnit.SECONDS) + keyConfig.keyPhaseOutTime.get shouldBe FiniteDuration(3, TimeUnit.SECONDS) } "The ldapConfig properties" should "Match" in { diff --git a/api/src/test/scala/za/co/absa/loginsvc/rest/service/jwt/JWTServiceTest.scala b/api/src/test/scala/za/co/absa/loginsvc/rest/service/jwt/JWTServiceTest.scala index 7832419..85468f1 100644 --- a/api/src/test/scala/za/co/absa/loginsvc/rest/service/jwt/JWTServiceTest.scala +++ b/api/src/test/scala/za/co/absa/loginsvc/rest/service/jwt/JWTServiceTest.scala @@ -197,7 +197,7 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { } it should "fail with a unreadable tokens" in { - an[MalformedJwtException] should be thrownBy { + an[JwtException] should be thrownBy { jwtService.refreshTokens(AccessToken("abc.def.ghi"), RefreshToken("123.456.789")) } } @@ -205,7 +205,7 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { def customTimedJwtService(accessExpTime: FiniteDuration, refreshExpTime: FiniteDuration): JWTService = { val configP = new JwtConfigProvider { override def getJwtKeyConfig: KeyConfig = InMemoryKeyConfig( - "RS256", accessExpTime, refreshExpTime, None, None + "RS256", accessExpTime, refreshExpTime, None, None, None ) } @@ -280,11 +280,11 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { behavior of "keyRotation" - it should "rotate an public and private keys after 5 seconds" in { + it should "rotate public and private keys after 14 seconds" in { val initToken = jwtService.generateAccessToken(userWithoutGroups) val initPublicKey = jwtService.publicKeys - Thread.sleep(6 * 1000) + Thread.sleep(14000) val refreshedToken = jwtService.generateAccessToken(userWithoutGroups) assert(parseJWT(initToken).isFailure) @@ -294,11 +294,11 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { assert(initPublicKey._1 == jwtService.publicKeys._2.orNull) } - it should "phase out older keys after 8 seconds" in { + it should "phase out older keys after 17 seconds" in { val initToken = jwtService.generateAccessToken(userWithoutGroups) val initPublicKey = jwtService.publicKeys - Thread.sleep(6 * 1000) + Thread.sleep(14000) val refreshedToken = jwtService.generateAccessToken(userWithoutGroups) assert(parseJWT(initToken).isFailure) @@ -307,7 +307,28 @@ class JWTServiceTest extends AnyFlatSpec with BeforeAndAfterEach with Matchers { assert(initPublicKey._1 != jwtService.publicKeys._1) assert(initPublicKey._1 == jwtService.publicKeys._2.orNull) - Thread.sleep(3 * 1000) + Thread.sleep(3000) assert(jwtService.publicKeys._2.isEmpty) } + + it should "lay over keys after 15 seconds" in { + val initToken = jwtService.generateAccessToken(userWithoutGroups) + val initPublicKey = jwtService.publicKeys + + Thread.sleep(11000) + + assert(parseJWT(initToken).isSuccess) + assert(initPublicKey != jwtService.publicKeys) + assert(initPublicKey._1 == jwtService.publicKeys._1) + assert(initPublicKey._1 != jwtService.publicKeys._2.orNull) + + Thread.sleep(4000) + val refreshedToken = jwtService.generateAccessToken(userWithoutGroups) + + assert(parseJWT(initToken).isFailure) + assert(parseJWT(refreshedToken).isSuccess) + assert(initPublicKey != jwtService.publicKeys) + assert(initPublicKey._1 != jwtService.publicKeys._1) + assert(initPublicKey._1 == jwtService.publicKeys._2.orNull) + } }