Skip to content

Commit

Permalink
feature: add support for request/response interceptors, support CORS …
Browse files Browse the repository at this point in the history
…with credentials (#199)

* feat: add request/response interceptor feature to route handling
* feat: move CORS functionality into CorsInterceptor
  • Loading branch information
tommytroen authored Mar 8, 2022
1 parent cf05496 commit dac8b61
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package no.nav.security.mock.oauth2.http

import mu.KotlinLogging

private val log = KotlinLogging.logger {}

class CorsInterceptor(
private val allowedMethods: List<String> = listOf("POST", "GET", "OPTIONS")
) : ResponseInterceptor {

companion object HeaderNames {
const val ORIGIN = "origin"
const val ACCESS_CONTROL_ALLOW_CREDENTIALS = "access-control-allow-credentials"
const val ACCESS_CONTROL_REQUEST_HEADERS = "access-control-request-headers"
const val ACCESS_CONTROL_ALLOW_HEADERS = "access-control-allow-headers"
const val ACCESS_CONTROL_ALLOW_METHODS = "access-control-allow-methods"
const val ACCESS_CONTROL_ALLOW_ORIGIN = "access-control-allow-origin"
}

override fun intercept(request: OAuth2HttpRequest, response: OAuth2HttpResponse): OAuth2HttpResponse {
val origin = request.headers[ORIGIN]
log.debug("intercept response if request origin header is set: $origin")
return if (origin != null) {
val headers = response.headers.newBuilder()
if (request.method == "OPTIONS") {
val reqHeader = request.headers[ACCESS_CONTROL_REQUEST_HEADERS]
if (reqHeader != null) {
headers[ACCESS_CONTROL_ALLOW_HEADERS] = reqHeader
}
headers[ACCESS_CONTROL_ALLOW_METHODS] = allowedMethods.joinToString(", ")
}
headers[ACCESS_CONTROL_ALLOW_ORIGIN] = origin
headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] = "true"
log.debug("adding CORS response headers")
response.copy(headers = headers.build())
} else {
response
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import com.nimbusds.oauth2.sdk.GrantType.REFRESH_TOKEN
import com.nimbusds.oauth2.sdk.OAuth2Error
import com.nimbusds.oauth2.sdk.ParseException
import com.nimbusds.openid.connect.sdk.AuthenticationRequest
import io.netty.handler.codec.http.HttpHeaderNames
import mu.KotlinLogging
import no.nav.security.mock.oauth2.OAuth2Config
import no.nav.security.mock.oauth2.OAuth2Exception
Expand Down Expand Up @@ -38,7 +37,6 @@ import no.nav.security.mock.oauth2.login.LoginRequestHandler
import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback
import no.nav.security.mock.oauth2.token.OAuth2TokenCallback
import no.nav.security.mock.oauth2.userinfo.userInfo
import okhttp3.Headers
import java.net.URLEncoder
import java.nio.charset.Charset
import java.util.concurrent.BlockingQueue
Expand Down Expand Up @@ -75,6 +73,7 @@ class OAuth2HttpRequestHandler(private val config: OAuth2Config) {

val authorizationServer: Route = routes {
exceptionHandler(exceptionHandler)
interceptors(CorsInterceptor())
wellKnown()
jwks()
authorization()
Expand Down Expand Up @@ -139,16 +138,7 @@ class OAuth2HttpRequestHandler(private val config: OAuth2Config) {
}
}

private fun Route.Builder.preflight() = options {
OAuth2HttpResponse(
status = 200,
headers = Headers.headersOf(
HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*",
HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS.toString(), "*",
HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS.toString(), "*"
)
)
}
private fun Route.Builder.preflight() = options { OAuth2HttpResponse(status = 204) }

private fun tokenCallbackFromQueueOrDefault(issuerId: String): OAuth2TokenCallback =
when (issuerId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ data class OAuth2TokenResponse(
fun json(anyObject: Any): OAuth2HttpResponse = OAuth2HttpResponse(
headers = Headers.headersOf(
HttpHeaderNames.CONTENT_TYPE.toString(), "application/json;charset=UTF-8",
HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*"
),
status = 200,
body = when (anyObject) {
Expand All @@ -78,7 +77,6 @@ fun json(anyObject: Any): OAuth2HttpResponse = OAuth2HttpResponse(
fun html(content: String): OAuth2HttpResponse = OAuth2HttpResponse(
headers = Headers.headersOf(
HttpHeaderNames.CONTENT_TYPE.toString(), "text/html;charset=UTF-8",
HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*"
),
status = 200,
body = content
Expand Down Expand Up @@ -116,7 +114,6 @@ fun oauth2Error(error: ErrorObject): OAuth2HttpResponse {
return OAuth2HttpResponse(
headers = Headers.headersOf(
HttpHeaderNames.CONTENT_TYPE.toString(), "application/json;charset=UTF-8",
HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString(), "*"
),
status = responseCode,
body = objectMapper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,35 @@ import no.nav.security.mock.oauth2.extensions.endsWith
private val log = KotlinLogging.logger { }

typealias RequestHandler = (OAuth2HttpRequest) -> OAuth2HttpResponse
internal typealias ExceptionHandler = (OAuth2HttpRequest, Throwable) -> OAuth2HttpResponse

interface Interceptor

fun interface RequestInterceptor : Interceptor {
fun intercept(request: OAuth2HttpRequest): OAuth2HttpRequest
}

fun interface ResponseInterceptor : Interceptor {
fun intercept(request: OAuth2HttpRequest, response: OAuth2HttpResponse): OAuth2HttpResponse
}

interface Route : RequestHandler {

fun match(request: OAuth2HttpRequest): Boolean

class Builder {
private val routes: MutableList<Route> = mutableListOf()
private val interceptors: MutableList<Interceptor> = mutableListOf()

private var exceptionHandler: ExceptionHandler = { _, throwable ->
throw throwable
}

fun interceptors(vararg interceptor: Interceptor) = apply {
interceptor.forEach {
interceptors.add(it)
}
}

fun attach(vararg route: Route) = apply {
route.forEach {
routes.add(it)
Expand Down Expand Up @@ -56,40 +73,64 @@ interface Route : RequestHandler {
routes.add(routeFromPathAndMethod(path, method, requestHandler))
}

fun build(): Route = object : PathRoute {
override fun matchPath(request: OAuth2HttpRequest): Boolean =
routes.any { it.matchPath(request) }

override fun match(request: OAuth2HttpRequest): Boolean =
routes.firstOrNull { it.match(request) } != null

override fun invoke(request: OAuth2HttpRequest): OAuth2HttpResponse =
try {
routes.firstOrNull { it.match(request) }?.invoke(request) ?: noMatch(request)
} catch (t: Throwable) {
exceptionHandler(request, t)
}

override fun toString(): String = routes.toString()

private fun noMatch(request: OAuth2HttpRequest): OAuth2HttpResponse {
log.debug("no route matching url=${request.url} with method=${request.method}")
return if (matchPath(request)) {
methodNotAllowed()
} else {
notFound("no routes found")
}
}

private fun Route.matchPath(request: OAuth2HttpRequest): Boolean = (this as? PathRoute)?.matchPath(request) ?: false
}
fun build(): Route = PathRouter(routes, interceptors, exceptionHandler)
}
}

internal typealias ExceptionHandler = (OAuth2HttpRequest, Throwable) -> OAuth2HttpResponse

internal interface PathRoute : Route {
fun matchPath(request: OAuth2HttpRequest): Boolean
}

internal class PathRouter(
private val routes: MutableList<Route>,
private val interceptors: MutableList<Interceptor>,
private val exceptionHandler: ExceptionHandler,
) : PathRoute {

override fun matchPath(request: OAuth2HttpRequest): Boolean = routes.any { it.matchPath(request) }
override fun match(request: OAuth2HttpRequest): Boolean = routes.firstOrNull { it.match(request) } != null

override fun invoke(request: OAuth2HttpRequest): OAuth2HttpResponse = runCatching {
routes.findHandler(request).invokeWith(request, interceptors)
}.getOrElse {
exceptionHandler(request, it)
}

override fun toString(): String = routes.toString()

private fun MutableList<Route>.findHandler(request: OAuth2HttpRequest): RequestHandler =
this.firstOrNull { it.match(request) } ?: { req -> noMatch(req) }

private fun RequestHandler.invokeWith(request: OAuth2HttpRequest, interceptors: MutableList<Interceptor>): OAuth2HttpResponse {
return if (interceptors.size > 0) {

val filteredRequest = interceptors.filterIsInstance<RequestInterceptor>().fold(request) { next, interceptor ->
interceptor.intercept(next)
}
val res = this.invoke(filteredRequest)
val filteredResponse = interceptors.filterIsInstance<ResponseInterceptor>().fold(res.copy()) { next, interceptor ->
interceptor.intercept(request, next)
}
filteredResponse
} else {
this.invoke(request)
}
}

private fun noMatch(request: OAuth2HttpRequest): OAuth2HttpResponse {
log.debug("no route matching url=${request.url} with method=${request.method}")
return if (matchPath(request)) {
methodNotAllowed()
} else {
notFound("no routes found")
}
}

private fun Route.matchPath(request: OAuth2HttpRequest): Boolean = (this as? PathRoute)?.matchPath(request) ?: false
}

fun routes(vararg route: Route): Route = routes {
attach(*route)
}
Expand Down
2 changes: 2 additions & 0 deletions src/test/kotlin/examples/kotlin/ktor/client/OAuth2Client.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import io.ktor.client.request.header
import io.ktor.http.Headers
import io.ktor.http.Parameters
import io.ktor.http.headersOf
import io.ktor.util.InternalAPI
import java.nio.charset.StandardCharsets
import java.security.KeyPair
import java.security.interfaces.RSAPrivateKey
Expand All @@ -33,6 +34,7 @@ val httpClient = HttpClient(CIO) {
}
}

@OptIn(InternalAPI::class)
suspend fun HttpClient.tokenRequest(url: String, auth: Auth, params: Map<String, String>) =
submitForm<TokenResponse>(
url = url,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,68 @@ package no.nav.security.mock.oauth2.e2e
import com.nimbusds.oauth2.sdk.GrantType
import io.kotest.assertions.asClue
import io.kotest.matchers.shouldBe
import io.netty.handler.codec.http.HttpHeaderNames
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_HEADERS
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_METHODS
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN
import no.nav.security.mock.oauth2.http.CorsInterceptor.HeaderNames.ACCESS_CONTROL_REQUEST_HEADERS
import no.nav.security.mock.oauth2.testutils.client
import no.nav.security.mock.oauth2.testutils.get
import no.nav.security.mock.oauth2.testutils.options
import no.nav.security.mock.oauth2.testutils.tokenRequest
import no.nav.security.mock.oauth2.token.DefaultOAuth2TokenCallback
import no.nav.security.mock.oauth2.withMockOAuth2Server
import okhttp3.Headers
import org.junit.jupiter.api.Test

class CorsHeadersIntegrationTest {
private val client = client()

private val origin = "https://theorigin"

@Test
fun `preflight response should allow all origin, all methods and all headers`() {
fun `preflight response should allow specific origin, methods and headers`() {
withMockOAuth2Server {
client.options(this.baseUrl()).asClue {
it.code shouldBe 200
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*"
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS.toString()] shouldBe "*"
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS.toString()] shouldBe "*"
client.options(
this.baseUrl(),
Headers.headersOf(
"origin", origin,
ACCESS_CONTROL_REQUEST_HEADERS, "X-MY-HEADER"
)
).asClue {
it.code shouldBe 204
it.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin
it.headers[ACCESS_CONTROL_ALLOW_METHODS] shouldBe "POST, GET, OPTIONS"
it.headers[ACCESS_CONTROL_ALLOW_HEADERS] shouldBe "X-MY-HEADER"
it.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true"
}
}
}

@Test
fun `wellknown response should allow all origins`() {
fun `wellknown response should allow origin`() {
withMockOAuth2Server {
client.get(this.wellKnownUrl("issuer")).asClue {
client.get(
this.wellKnownUrl("issuer"),
Headers.headersOf("origin", origin)
).asClue {
it.code shouldBe 200
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*"
it.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin
it.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true"
}
}
}

@Test
fun `jwks response should allow all origins`() {
withMockOAuth2Server {
client.get(this.jwksUrl("issuer")).asClue {
client.get(
this.jwksUrl("issuer"),
Headers.headersOf("origin", origin)
).asClue {
it.code shouldBe 200
it.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*"
it.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin
it.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true"
}
}
}
Expand All @@ -56,6 +78,7 @@ class CorsHeadersIntegrationTest {

val response = client.tokenRequest(
this.tokenEndpointUrl(issuerId),
Headers.headersOf("origin", origin),
mapOf(
"grant_type" to GrantType.REFRESH_TOKEN.value,
"refresh_token" to "canbewhatever",
Expand All @@ -65,7 +88,8 @@ class CorsHeadersIntegrationTest {
)

response.code shouldBe 200
response.headers[HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN.toString()] shouldBe "*"
response.headers[ACCESS_CONTROL_ALLOW_ORIGIN] shouldBe origin
response.headers[ACCESS_CONTROL_ALLOW_CREDENTIALS] shouldBe "true"
}
}
}
Loading

0 comments on commit dac8b61

Please sign in to comment.