Skip to content

Commit

Permalink
Use Instant with configurable Clocks in the API (#32)
Browse files Browse the repository at this point in the history
  • Loading branch information
nefilim authored Feb 5, 2023
1 parent bb65a3a commit b8e4b3b
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 51 deletions.
17 changes: 14 additions & 3 deletions core/src/main/kotlin/io/github/nefilim/kjwt/ClaimValidation.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ import arrow.core.invalidNel
import arrow.core.validNel
import arrow.core.zip
import arrow.typeclasses.Semigroup
import java.time.LocalDateTime
import java.time.Clock
import java.time.temporal.ChronoUnit

sealed interface KJWTValidationError: KJWTVerificationError {
object TokenExpired: KJWTValidationError
Expand All @@ -31,9 +32,19 @@ object ClaimsVerification {

fun audience(audience: String): ClaimsValidator = requiredOptionClaim("audience", { audience() }, { it == audience }, KJWTValidationError.InvalidAudience)

val expired: ClaimsValidator = requiredOptionClaim("expired", { expiresAt() }, { it.isAfter(LocalDateTime.now()) }, KJWTValidationError.TokenExpired)
fun expired(clock: Clock = Clock.systemUTC()): ClaimsValidator = requiredOptionClaim(
"expired",
{ expiresAt() },
{ it.isAfter(clock.instant().truncatedTo(ChronoUnit.SECONDS)) },
KJWTValidationError.TokenExpired
)

val notBefore: ClaimsValidator = requiredOptionClaim("notBefore", { notBefore() }, { it.isBefore(LocalDateTime.now()) }, KJWTValidationError.TokenNotValidYet)
fun notBefore(clock: Clock = Clock.systemUTC()): ClaimsValidator = requiredOptionClaim(
"notBefore",
{ notBefore() },
{ it.isBefore(clock.instant().truncatedTo(ChronoUnit.SECONDS)) },
KJWTValidationError.TokenNotValidYet
)

fun <T>requiredOptionClaim(
name: String,
Expand Down
28 changes: 14 additions & 14 deletions core/src/main/kotlin/io/github/nefilim/kjwt/JWT.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import kotlinx.serialization.json.intOrNull
import kotlinx.serialization.json.jsonArray
import kotlinx.serialization.json.jsonPrimitive
import kotlinx.serialization.json.longOrNull
import java.time.LocalDateTime
import java.time.ZoneOffset
import java.time.Clock
import java.time.Instant
import java.util.*

enum class JOSEType(val id: String) {
Expand Down Expand Up @@ -58,9 +58,9 @@ interface JWTClaims {
fun issuer(): Option<String>
fun subject(): Option<String>
fun audience(): Option<String>
fun expiresAt(): Option<LocalDateTime>
fun notBefore(): Option<LocalDateTime>
fun issuedAt(): Option<LocalDateTime>
fun expiresAt(): Option<Instant>
fun notBefore(): Option<Instant>
fun issuedAt(): Option<Instant>
fun jwtID(): Option<String>
}

Expand Down Expand Up @@ -91,10 +91,10 @@ class JWT<T: JWSAlgorithm> private constructor(
fun issuer(i: String) = claim("iss", i)
fun subject(s: String) = claim("sub", s)
fun audience(a: String) = claim("aud", a)
fun expiresAt(d: LocalDateTime) = claim("exp", d.jwtNumericDate())
fun notBefore(d: LocalDateTime) = claim("nbf", d.jwtNumericDate())
fun issuedAt(d: LocalDateTime) = claim("iat", d.jwtNumericDate())
fun issuedNow() = issuedAt(LocalDateTime.now())
fun expiresAt(d: Instant) = claim("exp", d.jwtNumericDate())
fun notBefore(d: Instant) = claim("nbf", d.jwtNumericDate())
fun issuedAt(d: Instant) = claim("iat", d.jwtNumericDate())
fun issuedNow(clock: Clock = Clock.systemUTC()) = issuedAt(clock.instant())
fun jwtID(id: String) = claim("jti", id)

fun build(): Map<String, JsonElement> = Collections.unmodifiableMap(values)
Expand Down Expand Up @@ -182,9 +182,9 @@ class JWT<T: JWSAlgorithm> private constructor(
override fun issuer(): Option<String> = claimValue("iss")
override fun subject(): Option<String> = claimValue("sub")
override fun audience(): Option<String> = claimValue("aud")
override fun expiresAt(): Option<LocalDateTime> = claimValueAsLong("exp").map { it.fromJWTNumericDate() }
override fun notBefore(): Option<LocalDateTime> = claimValueAsLong("nbf").map { it.fromJWTNumericDate() }
override fun issuedAt(): Option<LocalDateTime> = claimValueAsLong("iat").map { it.fromJWTNumericDate() }
override fun expiresAt(): Option<Instant> = claimValueAsLong("exp").map { it.fromJWTNumericDate() }
override fun notBefore(): Option<Instant> = claimValueAsLong("nbf").map { it.fromJWTNumericDate() }
override fun issuedAt(): Option<Instant> = claimValueAsLong("iat").map { it.fromJWTNumericDate() }
override fun jwtID(): Option<String> = claimValue("jti")

// generated
Expand Down Expand Up @@ -226,8 +226,8 @@ data class SignedJWT<T: JWSAlgorithm>(
}

// https://datatracker.ietf.org/doc/html/rfc7519#section-2
internal fun LocalDateTime.jwtNumericDate(): Long = this.toEpochSecond(ZoneOffset.UTC)
internal fun Long.fromJWTNumericDate(): LocalDateTime = LocalDateTime.ofEpochSecond(this, 0, ZoneOffset.UTC)
internal fun Instant.jwtNumericDate(): Long = this.epochSecond
internal fun Long.fromJWTNumericDate(): Instant = Instant.ofEpochSecond(this, 0L)
fun jwtEncodeBytes(data: ByteArray): String = String(Base64.getUrlEncoder().encode(data)).trimEnd('=') // remove trailing '=' as per JWT spec
fun jwtDecodeString(data: String): String = String(Base64.getUrlDecoder().decode(data))
internal fun decodeString(data: String): ByteArray = Base64.getUrlDecoder().decode(data)
Expand Down
83 changes: 51 additions & 32 deletions core/src/test/kotlin/io/github/nefilim/kjwt/JWTSpec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@ import io.kotest.assertions.arrow.core.shouldBeValid
import io.kotest.core.spec.style.WordSpec
import io.kotest.matchers.shouldBe
import mu.KotlinLogging
import java.time.Instant
import java.time.LocalDateTime
import java.time.ZoneId
import java.time.ZoneOffset
import java.time.*
import java.time.temporal.ChronoUnit
import com.nimbusds.jwt.SignedJWT as NimbusSignedJWT

class JWTSpec: WordSpec() {
Expand All @@ -50,14 +48,14 @@ class JWTSpec: WordSpec() {
issuer("nefilim")
claim("name", "John Doe")
claim("admin", true)
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}) {
keyID().shouldBeSome().id shouldBe "123"
issuer().shouldBeSome() shouldBe "nefilim"
subject().shouldBeSome() shouldBe "1234567890"
claimValue("name").shouldBeSome() shouldBe "John Doe"
claimValueAsBoolean("admin").shouldBeSome() shouldBe true
issuedAt().shouldBeSome() shouldBe LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC"))
issuedAt().shouldBeSome() shouldBe Instant.ofEpochSecond(1516239022)
}
}

Expand All @@ -66,7 +64,7 @@ class JWTSpec: WordSpec() {
subject("1234567890")
claim("name", "John Doe")
claim("admin", true)
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}

jwt.encode() shouldBe "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0"
Expand All @@ -75,7 +73,7 @@ class JWTSpec: WordSpec() {
subject("1234567890")
claim("name", "John Doe")
claim("admin", true)
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}

jwtWithKeyID.encode() shouldBe "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCIsImtpZCI6IjEyMyJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0"
Expand All @@ -86,7 +84,7 @@ class JWTSpec: WordSpec() {
subject("1234567890")
claim("name", "John Doe")
claim("admin", true)
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}

JWT.decode(rawJWT.encode()).shouldBeRight().also {
Expand All @@ -112,7 +110,7 @@ class JWTSpec: WordSpec() {
"decode spec violating types" {
val rawJWT = es256 {
subject("1234567890")
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}
// create a token with a spec violating lowercase type of "jwt"
val jwtString = listOf(
Expand All @@ -125,7 +123,7 @@ class JWTSpec: WordSpec() {
"""
{
"sub": "${rawJWT.subject().getOrElse { "" }}",
"iat": ${rawJWT.issuedAt().map { it.toEpochSecond(ZoneOffset.UTC) }.getOrElse { 0 }}
"iat": ${rawJWT.issuedAt().map { it.epochSecond }.getOrElse { 0 }}
}
""".trimIndent()
).joinToString(".") {
Expand All @@ -140,7 +138,7 @@ class JWTSpec: WordSpec() {
"decode JWT with missing type header" {
val rawJWT = es256WithoutTypeHeader {
subject("1234567890")
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}
// create a token with a spec violating lowercase type of "jwt"
val jwtString = listOf(
Expand All @@ -152,7 +150,7 @@ class JWTSpec: WordSpec() {
"""
{
"sub": "${rawJWT.subject().getOrElse { "" }}",
"iat": ${rawJWT.issuedAt().map { it.toEpochSecond(ZoneOffset.UTC) }.getOrElse { 0 }}
"iat": ${rawJWT.issuedAt().map { it.epochSecond }.getOrElse { 0 }}
}
""".trimIndent()
).joinToString(".") {
Expand All @@ -172,7 +170,7 @@ class JWTSpec: WordSpec() {
claim("admin", true)
claim("thenumber", 42)
claim("thelist", thelist)
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}

JWT.decode(rawJWT.encode()).shouldBeRight().also {
Expand All @@ -197,7 +195,7 @@ class JWTSpec: WordSpec() {
issuer("nefilim")
claim("name", "John Doe")
claim("admin", true)
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}

val (publicKey, privateKey) = generateKeyPair(rawJWT.header.algorithm)
Expand All @@ -221,7 +219,7 @@ class JWTSpec: WordSpec() {
subject("1234567890")
claim("name", "John Doe")
claim("admin", true)
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}

val (publicKey, privateKey) = generateKeyPair(rawJWT.header.algorithm)
Expand All @@ -244,7 +242,7 @@ class JWTSpec: WordSpec() {
subject("1234567890")
claim("name", "John Doe")
claim("admin", true)
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}

val secret = "iwFPzTFi41xBhlvqjYPiX4NKRFqAubl5zHAjeuK9s0MjvcCOgj84RgxRU2u8k7dUY1czPSCs4wAlePkLFTnsRpcaJdf07MJzloG63W1Mcfg9CCW9WOD80aOmkRnuYll5w8CYFj2qMP5D69XaGcjsu0rw6cjgkBhDDltSg5VZtDPYkGVuYw5NSUqk90PtKT9ZmF88bI2gadjhl3GS5ZBfOEisgNHnguQNfPFT3TDq8c5pLHoyAsErbNYaiwOjRfe2"
Expand All @@ -262,7 +260,7 @@ class JWTSpec: WordSpec() {
subject("1234567890")
claim("name", "John Doe")
claim("admin", true)
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
issuedAt(Instant.ofEpochSecond(1516239022))
}

// below stanzas should NOT compile, type system should prevent you from signing a JWT with the wrong kind of key
Expand All @@ -281,13 +279,13 @@ class JWTSpec: WordSpec() {
audience("http://thecompany.com")
claim("name", "John Doe")
claim("admin", true)
expiresAt(LocalDateTime.now().plusHours(1))
notBefore(LocalDateTime.now().minusMinutes(1))
expiresAt(Instant.now().plus(1, ChronoUnit.HOURS))
notBefore(Instant.now().minus(1, ChronoUnit.MINUTES))
issuedNow()
}

fun standardValidation(claims: JWTClaims): ValidatedNel<out KJWTVerificationError, JWTClaims> =
validateClaims(notBefore, expired, issuer("thecompany"), subject("1234567890"), audience("http://thecompany.com"))(claims)
validateClaims(notBefore(), expired(), issuer("thecompany"), subject("1234567890"), audience("http://thecompany.com"))(claims)

standardValidation(jwt).shouldBeValid()

Expand All @@ -297,9 +295,9 @@ class JWTSpec: WordSpec() {
audience("http://phish.com")
claim("name", "John Doe")
claim("admin", true)
expiresAt(LocalDateTime.now().minusHours(1))
notBefore(LocalDateTime.now().plusMinutes(1))
issuedAt(LocalDateTime.now())
expiresAt(Instant.now().minus(1, ChronoUnit.MINUTES))
notBefore(Instant.now().plus(1, ChronoUnit.MINUTES))
issuedAt(Instant.now())
}

standardValidation(invalidJWT).shouldBeInvalid().toSet()
Expand All @@ -314,14 +312,35 @@ class JWTSpec: WordSpec() {
)
}

"cross timezone issues and validation" {
fun standardValidation(claims: JWTClaims): ValidatedNel<out KJWTVerificationError, JWTClaims> =
validateClaims(notBefore(), expired())(claims)

val utcClock = Clock.systemUTC()
val defaultClock = Clock.systemDefaultZone()

val invalidJWT = es256() {
subject("123456789")
issuer("theothercompany")
audience("http://phish.com")
claim("name", "John Doe")
claim("admin", true)
expiresAt(defaultClock.instant().plus(1, ChronoUnit.MINUTES))
notBefore(utcClock.instant().minus(1, ChronoUnit.SECONDS))
issuedNow()
}

standardValidation(invalidJWT).shouldBeValid()
}

"support custom validations for required/optional claims" {
val jwt = es256() {
subject("1234567890")
issuer("theco")
claim("name", "John Doe")
claim("admin", true)
notBefore(LocalDateTime.now().plusMinutes(1))
issuedAt(LocalDateTime.ofInstant(Instant.ofEpochSecond(1516239022), ZoneId.of("UTC")))
notBefore(Instant.now().plus(1, ChronoUnit.MINUTES))
issuedAt(Instant.ofEpochSecond(1516239022))
}

validateClaims(requiredOptionClaim("admin", { claimValueAsBoolean("admin") }) { it })(jwt).shouldBeValid()
Expand All @@ -342,13 +361,13 @@ class JWTSpec: WordSpec() {
audience("http://thecompany.com")
claim("name", "John Doe")
claim("admin", true)
expiresAt(LocalDateTime.now().plusHours(1))
notBefore(LocalDateTime.now().minusMinutes(1))
expiresAt(Instant.now().plus(1, ChronoUnit.HOURS))
notBefore(Instant.now().minus(1, ChronoUnit.MINUTES))
issuedNow()
}

val standardValidation: ClaimsValidator = { claims ->
validateClaims(notBefore, expired, issuer("thecompany"), subject("1234567890"), audience("http://thecompany.com"))(claims)
validateClaims(notBefore(), expired(), issuer("thecompany"), subject("1234567890"), audience("http://thecompany.com"))(claims)
}
val signedJWT = jwt.sign(privateKey).shouldBeRight()
verify(signedJWT.rendered, ECPublicKeyProvider { publicKey.some() }, standardValidation, JWSES256Algorithm).shouldBeValid()
Expand All @@ -372,13 +391,13 @@ class JWTSpec: WordSpec() {
audience("http://thecompany.com")
claim("name", "John Doe")
claim("admin", true)
expiresAt(LocalDateTime.now().plusHours(1))
notBefore(LocalDateTime.now().minusMinutes(1))
expiresAt(Instant.now().plus(1, ChronoUnit.HOURS))
notBefore(Instant.now().minus(1, ChronoUnit.MINUTES))
issuedNow()
}

val standardValidation: ClaimsValidator = { claims ->
validateClaims(notBefore, expired, issuer("thecompany"), subject("1234567890"), audience("http://thecompany.com"))(claims)
validateClaims(notBefore(), expired(), issuer("thecompany"), subject("1234567890"), audience("http://thecompany.com"))(claims)
}
val signedJWT = jwt.sign(privateKey).shouldBeRight()
verify(signedJWT.rendered, ECPublicKeyProvider { publicKey.some() }, standardValidation, JWSES256Algorithm).shouldBeValid()
Expand Down
4 changes: 2 additions & 2 deletions core/src/test/kotlin/io/github/nefilim/kjwt/ProblemSpec.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ class ProblemSpec: WordSpec() {
}

suspend fun <T: JWSAsymmetricAlgorithm<PubK, PrivK>, PubK: PublicKey, PrivK: PrivateKey>processToken(jwt: String, jwtOps: JWTOperations<T, PubK, PrivK>): ClaimsValidatorResult {
val validator = validateClaims(notBefore, expired)
val validator = validateClaims(notBefore(), expired())
return verify<T, PubK, PrivK>(jwt, jwtOps.keyProvider, jwtOps.algorithm, validator)
}

suspend fun <T: JWSECDSAAlgorithm>processKnownToken(jwt: String, jwtOps: ECJWTOperations<T>): ClaimsValidatorResult {
val validator = validateClaims(notBefore, expired)
val validator = validateClaims(notBefore(), expired())
return verify(jwt, jwtOps.keyProvider, validator, jwtOps.algorithm)
}

Expand Down

0 comments on commit b8e4b3b

Please sign in to comment.