From 4aab3a5938f204084e7df4b4d5863f6aa279751b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tommy=20Tr=C3=B8en?= Date: Thu, 20 Jun 2024 08:56:20 +0200 Subject: [PATCH] feat: support objects and lists in request mapping claims (#699) * fixes #683 and #674 Co-authored-by: ybelmekk --- build.gradle.kts | 1 + .../mock/oauth2/extensions/Template.kt | 31 +++++++++ .../mock/oauth2/token/OAuth2TokenCallback.kt | 42 +++--------- .../mock/oauth2/extensions/TemplateTest.kt | 28 ++++++++ .../oauth2/token/OAuth2TokenCallbackTest.kt | 67 +++++++++++++++++++ 5 files changed, 136 insertions(+), 33 deletions(-) create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/extensions/Template.kt create mode 100644 src/test/kotlin/no/nav/security/mock/oauth2/extensions/TemplateTest.kt diff --git a/build.gradle.kts b/build.gradle.kts index a0454a28..2d55bb89 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -66,6 +66,7 @@ dependencies { implementation("com.fasterxml.jackson.module:jackson-module-kotlin:$jacksonVersion") implementation("org.freemarker:freemarker:$freemarkerVersion") implementation("org.bouncycastle:bcpkix-jdk18on:$bouncyCastleVersion") + implementation("org.jetbrains.kotlinx:kotlinx-serialization-json:1.6.3") testImplementation("org.assertj:assertj-core:$assertjVersion") testImplementation("org.junit.jupiter:junit-jupiter-api:$junitJupiterVersion") testImplementation("org.junit.jupiter:junit-jupiter-params:$junitJupiterVersion") diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/extensions/Template.kt b/src/main/kotlin/no/nav/security/mock/oauth2/extensions/Template.kt new file mode 100644 index 00000000..24465724 --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/extensions/Template.kt @@ -0,0 +1,31 @@ +package no.nav.security.mock.oauth2.extensions + +/** + * Replaces all template values denoted with ${key} in a map with the corresponding values from the templates map. + * + * @param templates a map of template values + * @return a new map with all template values replaced + */ +fun Map.replaceValues(templates: Map): Map { + fun replaceTemplateString( + value: String, + templates: Map, + ): String { + val regex = Regex("""\$\{(\w+)\}""") + return regex.replace(value) { matchResult -> + val key = matchResult.groupValues[1] + templates[key]?.toString() ?: matchResult.value + } + } + + fun replaceValue(value: Any): Any { + return when (value) { + is String -> replaceTemplateString(value, templates) + is List<*> -> value.map { it?.let { replaceValue(it) } } + is Map<*, *> -> value.mapValues { v -> v.value?.let { replaceValue(it) } } + else -> value + } + } + + return this.mapValues { replaceValue(it.value) } +} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt index 3a4b9f39..4432e217 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallback.kt @@ -5,6 +5,7 @@ import com.nimbusds.oauth2.sdk.GrantType import com.nimbusds.oauth2.sdk.TokenRequest import no.nav.security.mock.oauth2.extensions.clientIdAsString import no.nav.security.mock.oauth2.extensions.grantType +import no.nav.security.mock.oauth2.extensions.replaceValues import no.nav.security.mock.oauth2.extensions.scopesWithoutOidcScopes import no.nav.security.mock.oauth2.extensions.tokenExchangeGrantOrNull import java.time.Duration @@ -89,27 +90,14 @@ data class RequestMappingTokenCallback( private fun List.getClaims(tokenRequest: TokenRequest): Map { val claims = firstOrNull { it.isMatch(tokenRequest) }?.claims ?: emptyMap() - val customParameters = tokenRequest.customParameters.mapValues { (_, value) -> value.first() } - val variables = - if (tokenRequest.grantType() == GrantType.CLIENT_CREDENTIALS) { - customParameters + ("clientId" to tokenRequest.clientIdAsString()) - } else { - customParameters - } - return claims.mapValues { (_, value) -> - when (value) { - is String -> replaceVariables(value, variables) - is List<*> -> - value.map { v -> - if (v is String) { - replaceVariables(v, variables) - } else { - v - } - } - else -> value - } - } + val templateParams = tokenRequest.toHTTPRequest().bodyAsFormParameters.mapValues { it.value.joinToString(separator = " ") } + + // in case client_id is not set as form param but as basic auth, we add it to the template params in two different formats for backwards compatibility + return claims.replaceValues( + templateParams + + mapOf("clientId" to tokenRequest.clientIdAsString()) + + mapOf("client_id" to tokenRequest.clientIdAsString()), + ) } private inline fun List.getClaimOrNull( @@ -118,18 +106,6 @@ data class RequestMappingTokenCallback( ): T? = getClaims(tokenRequest)[key] as? T private fun List.getTypeHeader(tokenRequest: TokenRequest) = firstOrNull { it.isMatch(tokenRequest) }?.typeHeader ?: JOSEObjectType.JWT.type - - private fun replaceVariables( - input: String, - replacements: Map, - ): String { - val pattern = Regex("""\$\{(\w+)}""") - return pattern.replace(input) { result -> - val variableName = result.groupValues[1] - val replacement = replacements[variableName] - replacement ?: result.value - } - } } data class RequestMapping( diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/extensions/TemplateTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/extensions/TemplateTest.kt new file mode 100644 index 00000000..a75eca6a --- /dev/null +++ b/src/test/kotlin/no/nav/security/mock/oauth2/extensions/TemplateTest.kt @@ -0,0 +1,28 @@ +package no.nav.security.mock.oauth2.extensions + +import io.kotest.assertions.asClue +import io.kotest.matchers.shouldBe +import org.junit.jupiter.api.Test + +class TemplateTest { + @Test + fun `template values in map should be replaced`() { + val templates = + mapOf( + "templateVal1" to "val1", + "templateVal2" to "val2", + "templateListVal" to "listVal1", + ) + + mapOf( + "object1" to mapOf("key1" to "\${templateVal1}"), + "object2" to "\${templateVal2}", + "nestedObject" to mapOf("nestedKey" to mapOf("nestedKeyAgain" to "\${templateVal2}")), + "list1" to listOf("\${templateListVal}"), + ).replaceValues(templates).asClue { + it["object1"] shouldBe mapOf("key1" to "val1") + it["list1"] shouldBe listOf("listVal1") + println(it) + } + } +} diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt index 368c83fc..a7523a2b 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/token/OAuth2TokenCallbackTest.kt @@ -124,6 +124,73 @@ internal class OAuth2TokenCallbackTest { issuer1.typeHeader(grantTypeShouldMatch) shouldBe "JWT" } } + + @Test + fun `token request with custom parameters in token request should include claims with placeholder names`() { + val request = + clientCredentialsRequest( + "scope" to "testscope:something another:scope", + "mock_token_type" to "custom", + ) + RequestMappingTokenCallback( + issuerId = "issuer1", + requestMappings = + listOf( + RequestMapping( + requestParam = "scope", + match = "testscope:.*", + claims = + mapOf( + "sub" to "\${clientId}", + "scope" to "\${scope}", + "mock_token_type" to "\${mock_token_type}", + ), + ), + ), + ).addClaims(request).asClue { + it shouldContainAll mapOf("sub" to clientId, "scope" to "testscope:something another:scope", "mock_token_type" to "custom") + } + } + } + + @Test + fun `token request with custom parameters in token request should include claims with placeholder names`() { + val request = + clientCredentialsRequest( + "mock_token_type" to "custom", + "participantId" to "participantId", + "actAs" to "actAs", + "readAs" to "readAs", + ) + RequestMappingTokenCallback( + issuerId = "issuer1", + requestMappings = + listOf( + RequestMapping( + requestParam = "mock_token_type", + match = "custom", + claims = + mapOf( + "https://daml.com/ledger-api" to + mapOf( + "participantId" to "\${participantId}", + "actAs" to listOf("\${actAs}"), + "readAs" to listOf("\${readAs}"), + ), + ), + ), + ), + ).addClaims(request).asClue { + it shouldContainAll + mapOf( + "https://daml.com/ledger-api" to + mapOf( + "participantId" to "participantId", + "actAs" to listOf("actAs"), + "readAs" to listOf("readAs"), + ), + ) + } } @Nested