From dac8b61d737023f5ff30561bd4064dd891ff129a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tommy=20Tr=C3=B8en?= Date: Tue, 8 Mar 2022 11:30:11 +0100 Subject: [PATCH] feature: add support for request/response interceptors, support CORS with credentials (#199) * feat: add request/response interceptor feature to route handling * feat: move CORS functionality into CorsInterceptor --- .../mock/oauth2/http/CorsInterceptor.kt | 40 ++++++++ .../oauth2/http/OAuth2HttpRequestHandler.kt | 14 +-- .../mock/oauth2/http/OAuth2HttpResponse.kt | 3 - .../mock/oauth2/http/OAuth2HttpRouter.kt | 97 +++++++++++++------ .../kotlin/ktor/client/OAuth2Client.kt | 2 + .../oauth2/e2e/CorsHeadersIntegrationTest.kt | 50 +++++++--- .../mock/oauth2/http/OAuth2HttpRouterTest.kt | 50 +++++++++- .../security/mock/oauth2/testutils/Http.kt | 9 +- 8 files changed, 205 insertions(+), 60 deletions(-) create mode 100644 src/main/kotlin/no/nav/security/mock/oauth2/http/CorsInterceptor.kt diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/http/CorsInterceptor.kt b/src/main/kotlin/no/nav/security/mock/oauth2/http/CorsInterceptor.kt new file mode 100644 index 00000000..2217e0c0 --- /dev/null +++ b/src/main/kotlin/no/nav/security/mock/oauth2/http/CorsInterceptor.kt @@ -0,0 +1,40 @@ +package no.nav.security.mock.oauth2.http + +import mu.KotlinLogging + +private val log = KotlinLogging.logger {} + +class CorsInterceptor( + private val allowedMethods: List = listOf("POST", "GET", "OPTIONS") +) : ResponseInterceptor { + + companion object HeaderNames { + const val ORIGIN = "origin" + const val ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials" + const val ACCESS_CONTROL_REQUEST_HEADERS = "access-control-request-headers" + const val ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers" + const val ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods" + const val ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin" + } + + override fun intercept(request: OAuth2HttpRequest, response: OAuth2HttpResponse): OAuth2HttpResponse { + val origin = request.headers[ORIGIN] + log.debug("intercept response if request origin header is set: $origin") + return if (origin != null) { + val headers = response.headers.newBuilder() + if (request.method == "OPTIONS") { + val reqHeader = request.headers[ACCESS_CONTROL_REQUEST_HEADERS] + if (reqHeader != null) { + headers[ACCESS_CONTROL_ALLOW_HEADERS] = reqHeader + } + headers[ACCESS_CONTROL_ALLOW_METHODS] = allowedMethods.joinToString(", ") + } + headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin + headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true" + log.debug("adding CORS response headers") + response.copy(headers = headers.build()) + } else { + response + } + } +} diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt index de0ec60c..903ffc63 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRequestHandler.kt @@ -10,7 +10,6 @@ import com.nimbusds.oauth2.sdk.GrantType.REFRESH_TOKEN import com.nimbusds.oauth2.sdk.OAuth2Error import com.nimbusds.oauth2.sdk.ParseException import com.nimbusds.openid.connect.sdk.AuthenticationRequest -import io.netty.handler.codec.http.HttpHeaderNames import mu.KotlinLogging import no.nav.security.mock.oauth2.OAuth2Config import no.nav.security.mock.oauth2.OAuth2Exception @@ -38,7 +37,6 @@ import no.nav.security.mock.oauth2.login.LoginRequestHandler import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback import no.nav.security.mock.oauth2.token.OAuth2TokenCallback import no.nav.security.mock.oauth2.userinfo.userInfo -import okhttp3.Headers import java.net.URLEncoder import java.nio.charset.Charset import java.util.concurrent.BlockingQueue @@ -75,6 +73,7 @@ class OAuth2HttpRequestHandler(private val config: OAuth2Config) { val authorizationServer: Route = routes { exceptionHandler(exceptionHandler) + interceptors(CorsInterceptor()) wellKnown() jwks() authorization() @@ -139,16 +138,7 @@ class OAuth2HttpRequestHandler(private val config: OAuth2Config) { } } - private fun Route.Builder.preflight() = options { - OAuth2HttpResponse( - status = 200, - headers = Headers.headersOf( - HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*", - HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS.toString(), "*", - HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS.toString(), "*" - ) - ) - } + private fun Route.Builder.preflight() = options { OAuth2HttpResponse(status = 204) } private fun tokenCallbackFromQueueOrDefault(issuerId: String): OAuth2TokenCallback = when (issuerId) { diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpResponse.kt b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpResponse.kt index 623c6af2..cd91d4c2 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpResponse.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpResponse.kt @@ -63,7 +63,6 @@ data class OAuth2TokenResponse( fun json(anyObject: Any): OAuth2HttpResponse = OAuth2HttpResponse( headers = Headers.headersOf( HttpHeaderNames.CONTENT_TYPE.toString(), "application/json;charset=UTF-8", - HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*" ), status = 200, body = when (anyObject) { @@ -78,7 +77,6 @@ fun json(anyObject: Any): OAuth2HttpResponse = OAuth2HttpResponse( fun html(content: String): OAuth2HttpResponse = OAuth2HttpResponse( headers = Headers.headersOf( HttpHeaderNames.CONTENT_TYPE.toString(), "text/html;charset=UTF-8", - HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*" ), status = 200, body = content @@ -116,7 +114,6 @@ fun oauth2Error(error: ErrorObject): OAuth2HttpResponse { return OAuth2HttpResponse( headers = Headers.headersOf( HttpHeaderNames.CONTENT_TYPE.toString(), "application/json;charset=UTF-8", - HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*" ), status = responseCode, body = objectMapper diff --git a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRouter.kt b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRouter.kt index dd9ddbdd..83d00f7c 100644 --- a/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRouter.kt +++ b/src/main/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRouter.kt @@ -6,18 +6,35 @@ import no.nav.security.mock.oauth2.extensions.endsWith private val log = KotlinLogging.logger { } typealias RequestHandler = (OAuth2HttpRequest) -> OAuth2HttpResponse -internal typealias ExceptionHandler = (OAuth2HttpRequest, Throwable) -> OAuth2HttpResponse + +interface Interceptor + +fun interface RequestInterceptor : Interceptor { + fun intercept(request: OAuth2HttpRequest): OAuth2HttpRequest +} + +fun interface ResponseInterceptor : Interceptor { + fun intercept(request: OAuth2HttpRequest, response: OAuth2HttpResponse): OAuth2HttpResponse +} interface Route : RequestHandler { + fun match(request: OAuth2HttpRequest): Boolean class Builder { private val routes: MutableList = mutableListOf() + private val interceptors: MutableList = mutableListOf() private var exceptionHandler: ExceptionHandler = { _, throwable -> throw throwable } + fun interceptors(vararg interceptor: Interceptor) = apply { + interceptor.forEach { + interceptors.add(it) + } + } + fun attach(vararg route: Route) = apply { route.forEach { routes.add(it) @@ -56,40 +73,64 @@ interface Route : RequestHandler { routes.add(routeFromPathAndMethod(path, method, requestHandler)) } - fun build(): Route = object : PathRoute { - override fun matchPath(request: OAuth2HttpRequest): Boolean = - routes.any { it.matchPath(request) } - - override fun match(request: OAuth2HttpRequest): Boolean = - routes.firstOrNull { it.match(request) } != null - - override fun invoke(request: OAuth2HttpRequest): OAuth2HttpResponse = - try { - routes.firstOrNull { it.match(request) }?.invoke(request) ?: noMatch(request) - } catch (t: Throwable) { - exceptionHandler(request, t) - } - - override fun toString(): String = routes.toString() - - private fun noMatch(request: OAuth2HttpRequest): OAuth2HttpResponse { - log.debug("no route matching url=${request.url} with method=${request.method}") - return if (matchPath(request)) { - methodNotAllowed() - } else { - notFound("no routes found") - } - } - - private fun Route.matchPath(request: OAuth2HttpRequest): Boolean = (this as? PathRoute)?.matchPath(request) ?: false - } + fun build(): Route = PathRouter(routes, interceptors, exceptionHandler) } } +internal typealias ExceptionHandler = (OAuth2HttpRequest, Throwable) -> OAuth2HttpResponse + internal interface PathRoute : Route { fun matchPath(request: OAuth2HttpRequest): Boolean } +internal class PathRouter( + private val routes: MutableList, + private val interceptors: MutableList, + private val exceptionHandler: ExceptionHandler, +) : PathRoute { + + override fun matchPath(request: OAuth2HttpRequest): Boolean = routes.any { it.matchPath(request) } + override fun match(request: OAuth2HttpRequest): Boolean = routes.firstOrNull { it.match(request) } != null + + override fun invoke(request: OAuth2HttpRequest): OAuth2HttpResponse = runCatching { + routes.findHandler(request).invokeWith(request, interceptors) + }.getOrElse { + exceptionHandler(request, it) + } + + override fun toString(): String = routes.toString() + + private fun MutableList.findHandler(request: OAuth2HttpRequest): RequestHandler = + this.firstOrNull { it.match(request) } ?: { req -> noMatch(req) } + + private fun RequestHandler.invokeWith(request: OAuth2HttpRequest, interceptors: MutableList): OAuth2HttpResponse { + return if (interceptors.size > 0) { + + val filteredRequest = interceptors.filterIsInstance().fold(request) { next, interceptor -> + interceptor.intercept(next) + } + val res = this.invoke(filteredRequest) + val filteredResponse = interceptors.filterIsInstance().fold(res.copy()) { next, interceptor -> + interceptor.intercept(request, next) + } + filteredResponse + } else { + this.invoke(request) + } + } + + private fun noMatch(request: OAuth2HttpRequest): OAuth2HttpResponse { + log.debug("no route matching url=${request.url} with method=${request.method}") + return if (matchPath(request)) { + methodNotAllowed() + } else { + notFound("no routes found") + } + } + + private fun Route.matchPath(request: OAuth2HttpRequest): Boolean = (this as? PathRoute)?.matchPath(request) ?: false +} + fun routes(vararg route: Route): Route = routes { attach(*route) } diff --git a/src/test/kotlin/examples/kotlin/ktor/client/OAuth2Client.kt b/src/test/kotlin/examples/kotlin/ktor/client/OAuth2Client.kt index 862a3831..88696bc5 100644 --- a/src/test/kotlin/examples/kotlin/ktor/client/OAuth2Client.kt +++ b/src/test/kotlin/examples/kotlin/ktor/client/OAuth2Client.kt @@ -14,6 +14,7 @@ import io.ktor.client.request.header import io.ktor.http.Headers import io.ktor.http.Parameters import io.ktor.http.headersOf +import io.ktor.util.InternalAPI import java.nio.charset.StandardCharsets import java.security.KeyPair import java.security.interfaces.RSAPrivateKey @@ -33,6 +34,7 @@ val httpClient = HttpClient(CIO) { } } +@OptIn(InternalAPI::class) suspend fun HttpClient.tokenRequest(url: String, auth: Auth, params: Map) = submitForm( url = url, diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/e2e/CorsHeadersIntegrationTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/e2e/CorsHeadersIntegrationTest.kt index 8ccf2fa9..e4c93c50 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/e2e/CorsHeadersIntegrationTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/e2e/CorsHeadersIntegrationTest.kt @@ -3,36 +3,54 @@ package no.nav.security.mock.oauth2.e2e import com.nimbusds.oauth2.sdk.GrantType import io.kotest.assertions.asClue import io.kotest.matchers.shouldBe -import io.netty.handler.codec.http.HttpHeaderNames +import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS +import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_HEADERS +import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_METHODS +import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN +import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_REQUEST_HEADERS import no.nav.security.mock.oauth2.testutils.client import no.nav.security.mock.oauth2.testutils.get import no.nav.security.mock.oauth2.testutils.options import no.nav.security.mock.oauth2.testutils.tokenRequest import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback import no.nav.security.mock.oauth2.withMockOAuth2Server +import okhttp3.Headers import org.junit.jupiter.api.Test class CorsHeadersIntegrationTest { private val client = client() + private val origin = "https://theorigin" + @Test - fun `preflight response should allow all origin, all methods and all headers`() { + fun `preflight response should allow specific origin, methods and headers`() { withMockOAuth2Server { - client.options(this.baseUrl()).asClue { - it.code shouldBe 200 - it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*" - it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS.toString()] shouldBe "*" - it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS.toString()] shouldBe "*" + client.options( + this.baseUrl(), + Headers.headersOf( + "origin", origin, + ACCESS_CONTROL_REQUEST_HEADERS, "X-MY-HEADER" + ) + ).asClue { + it.code shouldBe 204 + it.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin + it.headers[ACCESS_CONTROL_ALLOW_METHODS] shouldBe "POST, GET, OPTIONS" + it.headers[ACCESS_CONTROL_ALLOW_HEADERS] shouldBe "X-MY-HEADER" + it.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true" } } } @Test - fun `wellknown response should allow all origins`() { + fun `wellknown response should allow origin`() { withMockOAuth2Server { - client.get(this.wellKnownUrl("issuer")).asClue { + client.get( + this.wellKnownUrl("issuer"), + Headers.headersOf("origin", origin) + ).asClue { it.code shouldBe 200 - it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*" + it.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin + it.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true" } } } @@ -40,9 +58,13 @@ class CorsHeadersIntegrationTest { @Test fun `jwks response should allow all origins`() { withMockOAuth2Server { - client.get(this.jwksUrl("issuer")).asClue { + client.get( + this.jwksUrl("issuer"), + Headers.headersOf("origin", origin) + ).asClue { it.code shouldBe 200 - it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*" + it.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin + it.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true" } } } @@ -56,6 +78,7 @@ class CorsHeadersIntegrationTest { val response = client.tokenRequest( this.tokenEndpointUrl(issuerId), + Headers.headersOf("origin", origin), mapOf( "grant_type" to GrantType.REFRESH_TOKEN.value, "refresh_token" to "canbewhatever", @@ -65,7 +88,8 @@ class CorsHeadersIntegrationTest { ) response.code shouldBe 200 - response.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*" + response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin + response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true" } } } diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRouterTest.kt b/src/test/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRouterTest.kt index 4cbd09ea..6ec3690d 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRouterTest.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/http/OAuth2HttpRouterTest.kt @@ -1,5 +1,8 @@ package no.nav.security.mock.oauth2.http +import io.kotest.assertions.asClue +import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.collections.shouldContainAll import io.kotest.matchers.shouldBe import okhttp3.Headers import okhttp3.HttpUrl.Companion.toHttpUrl @@ -24,6 +27,7 @@ internal class OAuth2HttpRouterTest { routes.invoke(options("/something/shouldmatch")).body shouldBe "OPTIONS" routes.invoke(get("/something/shouldmatch")).body shouldBe "GET" } + @Test fun `routes from route builder should be matched`() { val route = routes { @@ -53,6 +57,46 @@ internal class OAuth2HttpRouterTest { finalRoutes.invoke(post("/first/second")).status shouldBe 405 finalRoutes.invoke(get("/notfound")).status shouldBe 404 } + + @Test + fun `request and response interceptors should be applied on every route`() { + + val routes = routes { + interceptors( + RequestInterceptor { + val headers = it.headers.newBuilder().add("yolo", "forever").build() + it.copy(headers = headers) + }, + ResponseInterceptor { _, response -> + val headers = response.headers.newBuilder().add("fromInterceptor", "fromInterceptor").build() + response.copy(headers = headers) + } + ) + get("/1") { + it.headers shouldContain ("yolo" to "forever") + ok("1") + } + get("/2") { + it.headers shouldContain ("yolo" to "forever") + ok("2") + } + } + routes.invoke(get("/1")).asClue { + it.headers shouldContainAll listOf( + "Content-Type" to "text/plain", + "fromInterceptor" to "fromInterceptor" + ) + it.body shouldBe "1" + } + routes.invoke(get("/2")).asClue { + it.headers shouldContainAll listOf( + "Content-Type" to "text/plain", + "fromInterceptor" to "fromInterceptor" + ) + it.body shouldBe "2" + } + } + private fun get(path: String) = request("http://localhost$path", "GET") private fun post(path: String, body: String? = "na") = request("http://localhost$path", "POST", body) private fun options(path: String, body: String? = "na") = request("http://localhost$path", "OPTIONS", body) @@ -65,5 +109,9 @@ internal class OAuth2HttpRouterTest { body ) - private fun ok(body: String? = null) = OAuth2HttpResponse(status = 200, body = body) + private fun ok(body: String? = null) = OAuth2HttpResponse( + headers = Headers.headersOf("Content-Type", "text/plain"), + status = 200, + body = body + ) } diff --git a/src/test/kotlin/no/nav/security/mock/oauth2/testutils/Http.kt b/src/test/kotlin/no/nav/security/mock/oauth2/testutils/Http.kt index de8b5859..eb4a6922 100644 --- a/src/test/kotlin/no/nav/security/mock/oauth2/testutils/Http.kt +++ b/src/test/kotlin/no/nav/security/mock/oauth2/testutils/Http.kt @@ -96,11 +96,13 @@ fun OkHttpClient.get( ).execute() fun OkHttpClient.options( - url: HttpUrl + url: HttpUrl, + headers: Headers = Headers.headersOf(), ): Response = this.newCall( Request.Builder().options( - url + url, + headers ) ).execute() @@ -121,8 +123,9 @@ fun Request.Builder.post(url: HttpUrl, headers: Headers, parameters: Map