From c2446fa5f557aabedbcc15c7c9d9543ab2eacb5d Mon Sep 17 00:00:00 2001 From: Thanh Le Date: Fri, 23 Feb 2024 09:44:57 +0700 Subject: [PATCH] Catch exceptions when parsing request header Also make all class variables lazy just in case we don't use them --- src/main/scala/netty/RequestHandler.scala | 46 ++++++++++++----------- src/main/scala/util/RequestHeader.scala | 13 +++++-- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/src/main/scala/netty/RequestHandler.scala b/src/main/scala/netty/RequestHandler.scala index 2b3f7882..c42aae12 100644 --- a/src/main/scala/netty/RequestHandler.scala +++ b/src/main/scala/netty/RequestHandler.scala @@ -18,28 +18,30 @@ final private class RequestHandler(router: Router)(using Executor) ): override def channelRead0(ctx: ChannelHandlerContext, req: FullHttpRequest): Unit = - val request = util.RequestHeader(RequestUri(req.uri), req.headers) - router(request) foreach: - case Left(status) => - sendErrorResponse( - ctx, - DefaultFullHttpResponse(req.protocolVersion(), status, Unpooled.EMPTY_BUFFER) - ) - req.release() - case Right(_) if req.method != HttpMethod.GET => - val response = DefaultFullHttpResponse( - req.protocolVersion(), - HttpResponseStatus.METHOD_NOT_ALLOWED, - Unpooled.EMPTY_BUFFER - ) - response.headers.set(HttpHeaderNames.ALLOW, "GET") - sendErrorResponse(ctx, response) - req.release() - case Right(endpoint) => - Monitor.mobile.connect(request) - // Forward to ProtocolHandler with endpoint attribute - ctx.channel.attr(ProtocolHandler.key.endpoint).set(endpoint) - ctx.fireChannelRead(req) + util + .RequestHeader(RequestUri(req.uri), req.headers) + .foreach: request => + router(request).foreach: + case Left(status) => + sendErrorResponse( + ctx, + DefaultFullHttpResponse(req.protocolVersion(), status, Unpooled.EMPTY_BUFFER) + ) + req.release() + case Right(_) if req.method != HttpMethod.GET => + val response = DefaultFullHttpResponse( + req.protocolVersion(), + HttpResponseStatus.METHOD_NOT_ALLOWED, + Unpooled.EMPTY_BUFFER + ) + response.headers.set(HttpHeaderNames.ALLOW, "GET") + sendErrorResponse(ctx, response) + req.release() + case Right(endpoint) => + Monitor.mobile.connect(request) + // Forward to ProtocolHandler with endpoint attribute + ctx.channel.attr(ProtocolHandler.key.endpoint).set(endpoint) + ctx.fireChannelRead(req) private def sendErrorResponse(ctx: ChannelHandlerContext, response: DefaultFullHttpResponse) = val f = ctx.writeAndFlush(response) diff --git a/src/main/scala/util/RequestHeader.scala b/src/main/scala/util/RequestHeader.scala index 04381ea2..74be1238 100644 --- a/src/main/scala/util/RequestHeader.scala +++ b/src/main/scala/util/RequestHeader.scala @@ -12,11 +12,11 @@ object RequestUri extends OpaqueString[RequestUri] opaque type Domain = String object Domain extends OpaqueString[Domain] -final class RequestHeader(val uri: RequestUri, headers: HttpHeaders): +final class RequestHeader private (val uri: RequestUri, headers: HttpHeaders): - val (path, parameters) = - val qsd = QueryStringDecoder(uri.value) - (qsd.path, qsd.parameters) + lazy val qsd = QueryStringDecoder(uri.value) + lazy val path = qsd.path + lazy val parameters = qsd.parameters def header(name: CharSequence): Option[String] = Option(headers get name).filter(_.nonEmpty) @@ -56,3 +56,8 @@ final class RequestHeader(val uri: RequestUri, headers: HttpHeaders): def domain = Domain(header(HttpHeaderNames.HOST) getOrElse "?") override def toString = s"$name origin: $origin" + +object RequestHeader: + import cats.syntax.all.* + def apply(uri: RequestUri, headers: HttpHeaders): Either[Throwable, RequestHeader] = + Either.catchNonFatal(new RequestHeader(uri, headers))