diff --git a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala index ba9b0a0579..1689d01d05 100644 --- a/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala +++ b/server/core/src/main/scala/sttp/tapir/server/interpreter/ServerInterpreter.scala @@ -29,12 +29,14 @@ class ServerInterpreter[R, F[_], B, S]( ): RequestHandler[F, R, B] = { is match { case Nil => RequestHandler.from { (request, ses, _) => firstNotNone(request, ses, eisAcc.reverse, Nil) } - case (i: RequestInterceptor[F]) :: tail => - i( - responder, - { ei => RequestHandler.from { (request, ses, _) => callInterceptors(tail, ei :: eisAcc, responder).apply(request, ses) } } - ) - case (ei: EndpointInterceptor[F]) :: tail => callInterceptors(tail, ei :: eisAcc, responder) + case is => is.head match { + case ei: EndpointInterceptor[F] => callInterceptors(is.tail, ei :: eisAcc, responder) + case i: RequestInterceptor[F] => + i( + responder, + { ei => RequestHandler.from { (request, ses, _) => callInterceptors(is.tail, ei :: eisAcc, responder).apply(request, ses) } } + ) + } } } @@ -47,7 +49,8 @@ class ServerInterpreter[R, F[_], B, S]( ): F[RequestResult[B]] = ses match { case Nil => (RequestResult.Failure(accumulatedFailureContexts.reverse): RequestResult[B]).unit - case se :: tail => + case ses => + val se = ses.head tryServerEndpoint[se.SECURITY_INPUT, se.PRINCIPAL, se.INPUT, se.ERROR_OUTPUT, se.OUTPUT]( request, se, @@ -55,28 +58,29 @@ class ServerInterpreter[R, F[_], B, S]( ) .flatMap { case RequestResult.Failure(failureContexts) => - firstNotNone(request, tail, endpointInterceptors, failureContexts ++: accumulatedFailureContexts) + firstNotNone(request, ses.tail, endpointInterceptors, failureContexts ++: accumulatedFailureContexts) case r => r.unit } } - private def tryServerEndpoint[A, U, I, E, O]( - request: ServerRequest, - se: ServerEndpoint.Full[A, U, I, E, O, R, F], - endpointInterceptors: List[EndpointInterceptor[F]] - ): F[RequestResult[B]] = { - val defaultSecurityFailureResponse = + private val defaultSecurityFailureResponse = ServerResponse[B](StatusCode.InternalServerError, Nil, None, None).unit - def endpointHandler(securityFailureResponse: => F[ServerResponse[B]]): EndpointHandler[F, B] = + private def endpointHandler(securityFailureResponse: => F[ServerResponse[B]], endpointInterceptors: List[EndpointInterceptor[F]]): EndpointHandler[F, B] = endpointInterceptors.foldRight(defaultEndpointHandler(securityFailureResponse)) { case (interceptor, handler) => interceptor(responder(defaultSuccessStatusCode), handler) } - def resultOrValueFrom = new ResultOrValueFrom { + private def tryServerEndpoint[A, U, I, E, O]( + request: ServerRequest, + se: ServerEndpoint.Full[A, U, I, E, O, R, F], + endpointInterceptors: List[EndpointInterceptor[F]] + ): F[RequestResult[B]] = { + + val resultOrValueFrom = new ResultOrValueFrom { def onDecodeFailure(input: EndpointInput[_], failure: DecodeResult.Failure): F[RequestResult[B]] = { val decodeFailureContext = interceptor.DecodeFailureContext(se.endpoint, input, failure, request) - endpointHandler(defaultSecurityFailureResponse) + endpointHandler(defaultSecurityFailureResponse, endpointInterceptors) .onDecodeFailure(decodeFailureContext) .map { case Some(response) => RequestResult.Response(response) @@ -105,7 +109,7 @@ class ServerInterpreter[R, F[_], B, S]( // 4. running the security logic securityLogicResult <- ResultOrValue( se.securityLogic(monad)(a).map(Right(_): Either[RequestResult[B], Either[E, U]]).handleError { case t: Throwable => - endpointHandler(monad.error(t)) + endpointHandler(monad.error(t), endpointInterceptors) .onSecurityFailure(SecurityFailureContext(se, a, request)) .map(r => Left(RequestResult.Response(r)): Either[RequestResult[B], Either[E, U]]) } @@ -113,7 +117,7 @@ class ServerInterpreter[R, F[_], B, S]( response <- securityLogicResult match { case Left(e) => resultOrValueFrom.value( - endpointHandler(responder(defaultErrorStatusCode)(request, model.ValuedEndpointOutput(se.endpoint.errorOutput, e))) + endpointHandler(responder(defaultErrorStatusCode)(request, model.ValuedEndpointOutput(se.endpoint.errorOutput, e)), endpointInterceptors) .onSecurityFailure(SecurityFailureContext(se, a, request)) .map(r => RequestResult.Response(r): RequestResult[B]) ) @@ -124,7 +128,7 @@ class ServerInterpreter[R, F[_], B, S]( values <- resultOrValueFrom(decodeBody(request, inputValues)) params <- resultOrValueFrom(InputValue(se.endpoint.input, values)) response <- resultOrValueFrom.value( - endpointHandler(defaultSecurityFailureResponse) + endpointHandler(defaultSecurityFailureResponse, endpointInterceptors) .onDecodeSuccess(interceptor.DecodeSuccessContext(se, a, u, params.asAny.asInstanceOf[I], request)) .map(r => RequestResult.Response(r): RequestResult[B]) )