Skip to content

Commit

Permalink
Catch exceptions when parsing request header
Browse files Browse the repository at this point in the history
Also make all class variables lazy just in case we don't use them
  • Loading branch information
lenguyenthanh committed Feb 23, 2024
1 parent 981950b commit c2446fa
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 26 deletions.
46 changes: 24 additions & 22 deletions src/main/scala/netty/RequestHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,30 @@ final private class RequestHandler(router: Router)(using Executor)
):

override def channelRead0(ctx: ChannelHandlerContext, req: FullHttpRequest): Unit =
val request = util.RequestHeader(RequestUri(req.uri), req.headers)
router(request) foreach:
case Left(status) =>
sendErrorResponse(
ctx,
DefaultFullHttpResponse(req.protocolVersion(), status, Unpooled.EMPTY_BUFFER)
)
req.release()
case Right(_) if req.method != HttpMethod.GET =>
val response = DefaultFullHttpResponse(
req.protocolVersion(),
HttpResponseStatus.METHOD_NOT_ALLOWED,
Unpooled.EMPTY_BUFFER
)
response.headers.set(HttpHeaderNames.ALLOW, "GET")
sendErrorResponse(ctx, response)
req.release()
case Right(endpoint) =>
Monitor.mobile.connect(request)
// Forward to ProtocolHandler with endpoint attribute
ctx.channel.attr(ProtocolHandler.key.endpoint).set(endpoint)
ctx.fireChannelRead(req)
util
.RequestHeader(RequestUri(req.uri), req.headers)
.foreach: request =>
router(request).foreach:
case Left(status) =>
sendErrorResponse(
ctx,
DefaultFullHttpResponse(req.protocolVersion(), status, Unpooled.EMPTY_BUFFER)
)
req.release()
case Right(_) if req.method != HttpMethod.GET =>
val response = DefaultFullHttpResponse(
req.protocolVersion(),
HttpResponseStatus.METHOD_NOT_ALLOWED,
Unpooled.EMPTY_BUFFER
)
response.headers.set(HttpHeaderNames.ALLOW, "GET")
sendErrorResponse(ctx, response)
req.release()
case Right(endpoint) =>
Monitor.mobile.connect(request)
// Forward to ProtocolHandler with endpoint attribute
ctx.channel.attr(ProtocolHandler.key.endpoint).set(endpoint)
ctx.fireChannelRead(req)

private def sendErrorResponse(ctx: ChannelHandlerContext, response: DefaultFullHttpResponse) =
val f = ctx.writeAndFlush(response)
Expand Down
13 changes: 9 additions & 4 deletions src/main/scala/util/RequestHeader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ object RequestUri extends OpaqueString[RequestUri]
opaque type Domain = String
object Domain extends OpaqueString[Domain]

final class RequestHeader(val uri: RequestUri, headers: HttpHeaders):
final class RequestHeader private (val uri: RequestUri, headers: HttpHeaders):

val (path, parameters) =
val qsd = QueryStringDecoder(uri.value)
(qsd.path, qsd.parameters)
lazy val qsd = QueryStringDecoder(uri.value)
lazy val path = qsd.path
lazy val parameters = qsd.parameters

def header(name: CharSequence): Option[String] =
Option(headers get name).filter(_.nonEmpty)
Expand Down Expand Up @@ -56,3 +56,8 @@ final class RequestHeader(val uri: RequestUri, headers: HttpHeaders):
def domain = Domain(header(HttpHeaderNames.HOST) getOrElse "?")

override def toString = s"$name origin: $origin"

object RequestHeader:
import cats.syntax.all.*
def apply(uri: RequestUri, headers: HttpHeaders): Either[Throwable, RequestHeader] =
Either.catchNonFatal(new RequestHeader(uri, headers))

0 comments on commit c2446fa

Please sign in to comment.