Skip to content

Commit

Permalink
Fixed handling of CORS within Vert.x (#4232)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergiuszkierat authored Jan 9, 2025
1 parent 916d9ff commit 9ee00fe
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 7 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2112,6 +2112,7 @@ lazy val examples: ProjectMatrix = (projectMatrix in file("examples"))
sttpStubServer,
swaggerUiBundle,
redocBundle,
vertxServer,
zioHttpServer,
zioJson,
zioMetrics
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// {cat=Security; effects=Future; server=Vert.x}: CORS interceptor

//> using dep com.softwaremill.sttp.tapir::tapir-vertx-server:1.11.11
//> using dep com.softwaremill.sttp.client3::core:3.10.2

package sttp.tapir.examples.security

import io.vertx.core.Vertx
import io.vertx.ext.web.*
import sttp.client3.*
import sttp.model.headers.Origin
import sttp.model.{Header, HeaderNames, Method, StatusCode}
import sttp.tapir.*
import sttp.tapir.server.interceptor.cors.{CORSConfig, CORSInterceptor}
import sttp.tapir.server.vertx.VertxFutureServerInterpreter.*
import sttp.tapir.server.vertx.{VertxFutureServerInterpreter, VertxFutureServerOptions}

import scala.concurrent.duration.*
import scala.concurrent.{Await, ExecutionContext, Future}

@main def corsInterceptorVertxServer() =
given ExecutionContext = scala.concurrent.ExecutionContext.Implicits.global
val vertx = Vertx.vertx()

val server = vertx.createHttpServer()
val router = Router.router(vertx)

val myEndpoint = endpoint.get
.in("path")
.out(plainBody[String])
.serverLogic(_ => Future(Right("OK")))

val corsInterceptor = VertxFutureServerOptions.customiseInterceptors
.corsInterceptor(
CORSInterceptor.customOrThrow(
CORSConfig.default
.allowOrigin(Origin.Host("http", "my.origin"))
.allowMethods(Method.GET)
)
)
.options

val attach = VertxFutureServerInterpreter(corsInterceptor).route(myEndpoint)
attach(router)

// starting the server
val bindAndCheck = server.requestHandler(router).listen(9000).asScala.map { binding =>
val backend = HttpClientSyncBackend()

// Sending preflight request with allowed origin
val preflightResponse = basicRequest
.options(uri"http://localhost:9000/path")
.headers(
Header.origin(Origin.Host("http", "my.origin")),
Header.accessControlRequestMethod(Method.GET)
)
.send(backend)

assert(preflightResponse.code == StatusCode.NoContent)
assert(preflightResponse.headers.contains(Header.accessControlAllowOrigin("http://my.origin")))
assert(preflightResponse.headers.contains(Header.accessControlAllowMethods(Method.GET)))

println("Got expected response for preflight request")

// Sending preflight request with disallowed origin
val preflightResponseForDisallowedOrigin = basicRequest
.options(uri"http://localhost:9000/path")
.headers(
Header.origin(Origin.Host("http", "disallowed.com")),
Header.accessControlRequestMethod(Method.GET)
)
.send(backend)

// Check response does not contain allowed origin header
assert(preflightResponseForDisallowedOrigin.code == StatusCode.NoContent)
assert(!preflightResponseForDisallowedOrigin.headers.contains(Header.accessControlAllowOrigin("http://example.com")))

println("Got expected response for preflight request for wrong origin. No allowed origin header in response")

// Sending regular request from allowed origin
val requestResponse = basicRequest
.response(asStringAlways)
.get(uri"http://localhost:9000/path")
.headers(Header.origin(Origin.Host("http", "my.origin")))
.send(backend)

assert(requestResponse.code == StatusCode.Ok)
assert(requestResponse.body == "OK")
assert(requestResponse.headers.contains(Header.vary(HeaderNames.Origin)))
assert(requestResponse.headers.contains(Header.accessControlAllowOrigin("http://my.origin")))

println("Got expected response for regular request")

binding
}

Await.result(bindAndCheck.flatMap(_.close().asScala), 1.minute)
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ trait VertxCatsServerInterpreter[F[_]] extends CommonServerInterpreter with Vert
def route(
e: ServerEndpoint[Fs2Streams[F] with WebSockets, F]
): Router => Route = { router =>
val routeDef = extractRouteDefinition(e.endpoint)
val readStreamCompatible = fs2ReadStreamCompatible(vertxCatsServerOptions)
mountWithDefaultHandlers(e)(router, extractRouteDefinition(e.endpoint), vertxCatsServerOptions)
optionsRouteIfCORSDefined(e)(router, routeDef, vertxCatsServerOptions)
.foreach(_.handler(endpointHandler(e, readStreamCompatible)))
mountWithDefaultHandlers(e)(router, routeDef, vertxCatsServerOptions)
.handler(endpointHandler(e, readStreamCompatible))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ package sttp.tapir.server.vertx

import io.vertx.core.{Handler, Future => VFuture}
import io.vertx.ext.web.{Route, Router, RoutingContext}
import sttp.monad.FutureMonad
import sttp.capabilities.WebSockets
import sttp.monad.FutureMonad
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.interceptor.RequestResult
import sttp.tapir.server.interpreter.{BodyListener, ServerInterpreter}
Expand All @@ -26,7 +26,10 @@ trait VertxFutureServerInterpreter extends CommonServerInterpreter with VertxErr
* A function, that given a router, will attach this endpoint to it
*/
def route[A, U, I, E, O](e: ServerEndpoint[VertxStreams with WebSockets, Future]): Router => Route = { router =>
mountWithDefaultHandlers(e)(router, extractRouteDefinition(e.endpoint), vertxFutureServerOptions)
val routeDef = extractRouteDefinition(e.endpoint)
optionsRouteIfCORSDefined(e)(router, routeDef, vertxFutureServerOptions)
.foreach(_.handler(endpointHandler(e)))
mountWithDefaultHandlers(e)(router, routeDef, vertxFutureServerOptions)
.handler(endpointHandler(e))
}

Expand All @@ -37,7 +40,10 @@ trait VertxFutureServerInterpreter extends CommonServerInterpreter with VertxErr
* A function, that given a router, will attach this endpoint to it
*/
def blockingRoute(e: ServerEndpoint[VertxStreams with WebSockets, Future]): Router => Route = { router =>
mountWithDefaultHandlers(e)(router, extractRouteDefinition(e.endpoint), vertxFutureServerOptions)
val routeDef = extractRouteDefinition(e.endpoint)
optionsRouteIfCORSDefined(e)(router, routeDef, vertxFutureServerOptions)
.foreach(_.handler(endpointHandler(e)))
mountWithDefaultHandlers(e)(router, routeDef, vertxFutureServerOptions)
.blockingHandler(endpointHandler(e))
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,45 @@
package sttp.tapir.server.vertx.interpreters

import io.vertx.core.http.HttpMethod._
import io.vertx.ext.web.{Route, Router}
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.interceptor.Interceptor
import sttp.tapir.server.interceptor.cors.CORSInterceptor
import sttp.tapir.server.vertx.VertxServerOptions
import sttp.tapir.server.vertx.handlers.attachDefaultHandlers
import sttp.tapir.server.vertx.routing.PathMapping.{RouteDefinition, createRoute}

trait CommonServerInterpreter {

/** Checks if a CORS interceptor is defined in the server options and creates an OPTIONS route if it is.
*
* Vert.x will signal a 405 error if a route matches the path, but doesn’t match the HTTP Method. So if CORS is defined, we additionally
* register OPTIONS route which accepts the preflight requests.
*
* @return
* An optional Route. If a CORS interceptor is defined, an OPTIONS route is created and returned. Otherwise, None is returned.
*/
protected def optionsRouteIfCORSDefined[C, F[_]](
e: ServerEndpoint[C, F]
)(router: Router, routeDef: RouteDefinition, serverOptions: VertxServerOptions[F]): Option[Route] = {
def isCORSInterceptorDefined(interceptors: List[Interceptor[F]]): Boolean = {
interceptors.collectFirst { case ci: CORSInterceptor[F] => ci }.nonEmpty
}

def createOptionsRoute(router: Router, route: RouteDefinition): Option[Route] =
route match {
case (Some(method), path) if Set(GET, HEAD, POST, PUT, DELETE).contains(method) =>
Some(router.options(path))
case (None, path) => Some(router.options(path))
case _ => None
}

if (isCORSInterceptorDefined(serverOptions.interceptors)) {
createOptionsRoute(router, routeDef)
} else
None
}

protected def mountWithDefaultHandlers[C, F[_]](e: ServerEndpoint[C, F])(
router: Router,
routeDef: RouteDefinition,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package sttp.tapir.server.vertx.routing

import io.vertx.core.http.HttpMethod
import io.vertx.ext.web.{Route, Router}
import sttp.tapir.{AnyEndpoint, EndpointInput}
import sttp.tapir.EndpointInput.PathCapture
import sttp.tapir.internal._
import sttp.tapir.{AnyEndpoint, EndpointInput}

object PathMapping {

Expand Down Expand Up @@ -49,5 +49,4 @@ object PathMapping {
.mkString
if (path.isEmpty) "/*" else path
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ trait VertxZioServerInterpreter[R] extends CommonServerInterpreter with VertxErr
def route[R2](e: ZServerEndpoint[R2, ZioStreams with WebSockets])(implicit
runtime: Runtime[R & R2]
): Router => Route = { router =>
mountWithDefaultHandlers(e.widen)(router, extractRouteDefinition(e.endpoint), vertxZioServerOptions)
val routeDef = extractRouteDefinition(e.endpoint)
optionsRouteIfCORSDefined(e.widen)(router, routeDef, vertxZioServerOptions)
.foreach(_.handler(endpointHandler(e)))
mountWithDefaultHandlers(e.widen)(router, routeDef, vertxZioServerOptions)
.handler(endpointHandler(e))
}

Expand Down

0 comments on commit 9ee00fe

Please sign in to comment.