Skip to content

Commit

Permalink
Server interpreter optimizations. (#3076)
Browse files Browse the repository at this point in the history
Co-authored-by: Kamil Kloch <[email protected]>
  • Loading branch information
kamilkloch and Kamil Kloch authored Aug 10, 2023
1 parent 7611ce2 commit d153cdf
Showing 1 changed file with 24 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ class ServerInterpreter[R, F[_], B, S](
): RequestHandler[F, R, B] = {
is match {
case Nil => RequestHandler.from { (request, ses, _) => firstNotNone(request, ses, eisAcc.reverse, Nil) }
case (i: RequestInterceptor[F]) :: tail =>
i(
responder,
{ ei => RequestHandler.from { (request, ses, _) => callInterceptors(tail, ei :: eisAcc, responder).apply(request, ses) } }
)
case (ei: EndpointInterceptor[F]) :: tail => callInterceptors(tail, ei :: eisAcc, responder)
case is => is.head match {
case ei: EndpointInterceptor[F] => callInterceptors(is.tail, ei :: eisAcc, responder)
case i: RequestInterceptor[F] =>
i(
responder,
{ ei => RequestHandler.from { (request, ses, _) => callInterceptors(is.tail, ei :: eisAcc, responder).apply(request, ses) } }
)
}
}
}

Expand All @@ -47,36 +49,38 @@ class ServerInterpreter[R, F[_], B, S](
): F[RequestResult[B]] =
ses match {
case Nil => (RequestResult.Failure(accumulatedFailureContexts.reverse): RequestResult[B]).unit
case se :: tail =>
case ses =>
val se = ses.head
tryServerEndpoint[se.SECURITY_INPUT, se.PRINCIPAL, se.INPUT, se.ERROR_OUTPUT, se.OUTPUT](
request,
se,
endpointInterceptors
)
.flatMap {
case RequestResult.Failure(failureContexts) =>
firstNotNone(request, tail, endpointInterceptors, failureContexts ++: accumulatedFailureContexts)
firstNotNone(request, ses.tail, endpointInterceptors, failureContexts ++: accumulatedFailureContexts)
case r => r.unit
}
}

private def tryServerEndpoint[A, U, I, E, O](
request: ServerRequest,
se: ServerEndpoint.Full[A, U, I, E, O, R, F],
endpointInterceptors: List[EndpointInterceptor[F]]
): F[RequestResult[B]] = {
val defaultSecurityFailureResponse =
private val defaultSecurityFailureResponse =
ServerResponse[B](StatusCode.InternalServerError, Nil, None, None).unit

def endpointHandler(securityFailureResponse: => F[ServerResponse[B]]): EndpointHandler[F, B] =
private def endpointHandler(securityFailureResponse: => F[ServerResponse[B]], endpointInterceptors: List[EndpointInterceptor[F]]): EndpointHandler[F, B] =
endpointInterceptors.foldRight(defaultEndpointHandler(securityFailureResponse)) { case (interceptor, handler) =>
interceptor(responder(defaultSuccessStatusCode), handler)
}

def resultOrValueFrom = new ResultOrValueFrom {
private def tryServerEndpoint[A, U, I, E, O](
request: ServerRequest,
se: ServerEndpoint.Full[A, U, I, E, O, R, F],
endpointInterceptors: List[EndpointInterceptor[F]]
): F[RequestResult[B]] = {

val resultOrValueFrom = new ResultOrValueFrom {
def onDecodeFailure(input: EndpointInput[_], failure: DecodeResult.Failure): F[RequestResult[B]] = {
val decodeFailureContext = interceptor.DecodeFailureContext(se.endpoint, input, failure, request)
endpointHandler(defaultSecurityFailureResponse)
endpointHandler(defaultSecurityFailureResponse, endpointInterceptors)
.onDecodeFailure(decodeFailureContext)
.map {
case Some(response) => RequestResult.Response(response)
Expand Down Expand Up @@ -105,15 +109,15 @@ class ServerInterpreter[R, F[_], B, S](
// 4. running the security logic
securityLogicResult <- ResultOrValue(
se.securityLogic(monad)(a).map(Right(_): Either[RequestResult[B], Either[E, U]]).handleError { case t: Throwable =>
endpointHandler(monad.error(t))
endpointHandler(monad.error(t), endpointInterceptors)
.onSecurityFailure(SecurityFailureContext(se, a, request))
.map(r => Left(RequestResult.Response(r)): Either[RequestResult[B], Either[E, U]])
}
)
response <- securityLogicResult match {
case Left(e) =>
resultOrValueFrom.value(
endpointHandler(responder(defaultErrorStatusCode)(request, model.ValuedEndpointOutput(se.endpoint.errorOutput, e)))
endpointHandler(responder(defaultErrorStatusCode)(request, model.ValuedEndpointOutput(se.endpoint.errorOutput, e)), endpointInterceptors)
.onSecurityFailure(SecurityFailureContext(se, a, request))
.map(r => RequestResult.Response(r): RequestResult[B])
)
Expand All @@ -124,7 +128,7 @@ class ServerInterpreter[R, F[_], B, S](
values <- resultOrValueFrom(decodeBody(request, inputValues))
params <- resultOrValueFrom(InputValue(se.endpoint.input, values))
response <- resultOrValueFrom.value(
endpointHandler(defaultSecurityFailureResponse)
endpointHandler(defaultSecurityFailureResponse, endpointInterceptors)
.onDecodeSuccess(interceptor.DecodeSuccessContext(se, a, u, params.asAny.asInstanceOf[I], request))
.map(r => RequestResult.Response(r): RequestResult[B])
)
Expand Down

0 comments on commit d153cdf

Please sign in to comment.