-
Notifications
You must be signed in to change notification settings - Fork 59
/
OAuth2TokenCallback.kt
120 lines (96 loc) · 4.93 KB
/
OAuth2TokenCallback.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
package no.nav.security.mock.oauth2.token
import com.nimbusds.jose.JOSEObjectType
import com.nimbusds.oauth2.sdk.GrantType
import com.nimbusds.oauth2.sdk.TokenRequest
import no.nav.security.mock.oauth2.extensions.clientIdAsString
import no.nav.security.mock.oauth2.extensions.grantType
import no.nav.security.mock.oauth2.extensions.replaceValues
import no.nav.security.mock.oauth2.extensions.scopesWithoutOidcScopes
import no.nav.security.mock.oauth2.extensions.tokenExchangeGrantOrNull
import java.time.Duration
import java.util.UUID
/**
 * Hook used to customize the tokens issued by the mock server.
 *
 * Every method except [issuerId] and [tokenExpiry] receives the incoming [TokenRequest], so an
 * implementation can vary the token contents per request.
 */
interface OAuth2TokenCallback {
    /** Identifier of the issuer this callback is registered for. */
    fun issuerId(): String

    /** Subject (`sub`) of the token, or `null` when the implementation has no subject to supply. */
    fun subject(tokenRequest: TokenRequest): String?

    /** Value of the JOSE `typ` header for the token. */
    fun typeHeader(tokenRequest: TokenRequest): String

    /** Audience values for the token; may be empty. */
    fun audience(tokenRequest: TokenRequest): List<String>

    /** Extra claims to include in the token, keyed by claim name. */
    fun addClaims(tokenRequest: TokenRequest): Map<String, Any>

    /** Token lifetime in seconds (the implementations in this file default to one hour). */
    fun tokenExpiry(): Long
}
// TODO: for JwtBearerGrant and TokenExchange it should be possible to override sub; make sub nullable and return some default
/**
 * [OAuth2TokenCallback] that returns fixed values configured through its constructor.
 *
 * @param issuerId value returned by [issuerId] and also put into the `tid` claim.
 * @param subject subject used for every grant except `client_credentials`, where the client id is used instead.
 * @param typeHeader value of the JOSE `typ` header.
 * @param audience explicit audience; `null` means "not configured", in which case the audience is derived
 *   from the request. An empty list is a valid explicit value.
 * @param claims extra claims added to every token; they are applied after `tid` and can therefore override it.
 * @param expiry token lifetime in seconds.
 */
open class DefaultOAuth2TokenCallback
@JvmOverloads
constructor(
    private val issuerId: String = "default",
    private val subject: String = UUID.randomUUID().toString(),
    private val typeHeader: String = JOSEObjectType.JWT.type,
    // Nullable so that an explicitly configured list (including an empty one, which is an allowable value)
    // can be told apart from "not set".
    private val audience: List<String>? = null,
    private val claims: Map<String, Any> = emptyMap(),
    private val expiry: Long = 3600,
) : OAuth2TokenCallback {
    override fun issuerId(): String = issuerId

    /** Uses the client id as subject for `client_credentials` requests, otherwise the configured [subject]. */
    override fun subject(tokenRequest: TokenRequest): String =
        if (tokenRequest.grantType() == GrantType.CLIENT_CREDENTIALS) {
            tokenRequest.clientIdAsString()
        } else {
            subject
        }

    override fun typeHeader(tokenRequest: TokenRequest): String = typeHeader

    /**
     * Resolves the audience in priority order:
     * 1. the explicitly configured [audience],
     * 2. the audience of a token exchange grant,
     * 3. the requested scopes, with the OIDC scopes removed,
     * 4. `["default"]`.
     */
    override fun audience(tokenRequest: TokenRequest): List<String> {
        val audienceParam = tokenRequest.tokenExchangeGrantOrNull()?.audience
        return when {
            audience != null -> audience
            audienceParam != null -> audienceParam
            tokenRequest.scope != null -> tokenRequest.scopesWithoutOidcScopes()
            else -> listOf("default")
        }
    }

    /**
     * Returns `tid`, overlaid with the configured [claims].
     * For authorization code requests, it also sets `azp` to the client id.
     */
    override fun addClaims(tokenRequest: TokenRequest): Map<String, Any> =
        mutableMapOf<String, Any>(
            "tid" to issuerId,
        ).apply {
            putAll(claims)
            if (tokenRequest.grantType() == GrantType.AUTHORIZATION_CODE) {
                put("azp", tokenRequest.clientIdAsString())
            }
        }

    override fun tokenExpiry(): Long = expiry
}
/**
 * [OAuth2TokenCallback] driven by a list of [RequestMapping]s.
 *
 * The first mapping whose [RequestMapping.isMatch] accepts the token request supplies the claims and the `typ`
 * header. Claim values are passed through `replaceValues`. The template parameters are the request's form
 * parameters plus the client id under both `clientId` and `client_id`.
 *
 * @param issuerId identifier of the issuer this callback is registered for.
 * @param requestMappings candidate mappings, evaluated in list order.
 * @param tokenExpiry token lifetime in seconds, one hour by default.
 */
data class RequestMappingTokenCallback(
    val issuerId: String,
    val requestMappings: List<RequestMapping>,
    val tokenExpiry: Long = Duration.ofHours(1).toSeconds(),
) : OAuth2TokenCallback {
    override fun issuerId(): String = issuerId

    override fun subject(tokenRequest: TokenRequest): String? = requestMappings.getClaimOrNull(tokenRequest, "sub")

    override fun typeHeader(tokenRequest: TokenRequest): String = requestMappings.getTypeHeader(tokenRequest)

    /**
     * Audience taken from the matched mapping's `aud` claim.
     *
     * A JWT `aud` may be a single string or an array of strings, so both forms are accepted. A plain
     * `as? List<String>` would silently discard a single-string audience, and because of type erasure it would
     * not check the element type. Non-string list elements are therefore dropped. A missing or otherwise-typed
     * value yields an empty list.
     */
    override fun audience(tokenRequest: TokenRequest): List<String> =
        when (val aud = requestMappings.getClaims(tokenRequest)["aud"]) {
            is String -> listOf(aud)
            is List<*> -> aud.filterIsInstance<String>()
            else -> emptyList()
        }

    override fun addClaims(tokenRequest: TokenRequest): Map<String, Any> = requestMappings.getClaims(tokenRequest)

    override fun tokenExpiry(): Long = tokenExpiry

    /**
     * Claims of the first matching mapping (empty if none match), with template placeholders replaced
     * using the request parameters.
     */
    private fun List<RequestMapping>.getClaims(tokenRequest: TokenRequest): Map<String, Any> {
        val claims = firstOrNull { it.isMatch(tokenRequest) }?.claims ?: emptyMap()
        val templateParams =
            tokenRequest.toHTTPRequest().bodyAsFormParameters.mapValues { it.value.joinToString(separator = " ") }
        // client_id may be sent via basic auth rather than as a form param, so it is always added explicitly.
        // It is added under two key spellings for backwards compatibility.
        val clientId = tokenRequest.clientIdAsString()
        return claims.replaceValues(
            templateParams + mapOf("clientId" to clientId, "client_id" to clientId),
        )
    }

    /** Single claim from [getClaims], or `null` when it is absent or not (erasure-wise) of type [T]. */
    private inline fun <reified T> List<RequestMapping>.getClaimOrNull(
        tokenRequest: TokenRequest,
        key: String,
    ): T? = getClaims(tokenRequest)[key] as? T

    /** `typ` header of the first matching mapping, defaulting to `JWT`. */
    private fun List<RequestMapping>.getTypeHeader(tokenRequest: TokenRequest) =
        firstOrNull { it.isMatch(tokenRequest) }?.typeHeader ?: JOSEObjectType.JWT.type
}
/**
 * Selects the claims and `typ` header for token requests whose form parameter [requestParam] matches [match].
 *
 * @param requestParam name of the form parameter in the token request body to inspect.
 * @param match `*` to accept any value. Otherwise either a literal value or a regular expression that must
 *   match a parameter value in full.
 * @param claims claims to use when this mapping matches.
 * @param typeHeader JOSE `typ` header to use when this mapping matches.
 */
data class RequestMapping(
    private val requestParam: String,
    private val match: String,
    val claims: Map<String, Any> = emptyMap(),
    val typeHeader: String = JOSEObjectType.JWT.type,
) {
    // Compiled once and reused, instead of once per parameter value on every isMatch call.
    // It is lazy so that "*" (not itself a valid regex) and literal matches never trigger compilation.
    // An invalid pattern still throws at the point the regex is first needed, exactly as before.
    private val matchRegex: Regex by lazy { match.toRegex() }

    /** True when any value of [requestParam] in the request body is accepted by [match]. */
    fun isMatch(tokenRequest: TokenRequest): Boolean =
        tokenRequest.toHTTPRequest().bodyAsFormParameters[requestParam]?.any { value ->
            match == "*" || match == value || matchRegex.matchEntire(value) != null
        } ?: false
}