From c832d40e7c28e79db39aac576e077cb2c89ccb2f Mon Sep 17 00:00:00 2001 From: Jonathan Gamble Date: Sat, 30 Dec 2023 12:24:22 -0600 Subject: [PATCH] close connections when rate limits are exceeded --- src/main/scala/netty/FrameHandler.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/main/scala/netty/FrameHandler.scala b/src/main/scala/netty/FrameHandler.scala index e0587393..9cad5fde 100644 --- a/src/main/scala/netty/FrameHandler.scala +++ b/src/main/scala/netty/FrameHandler.scala @@ -2,14 +2,12 @@ package lila.ws package netty import com.typesafe.scalalogging.Logger +import io.netty.channel.Channel import io.netty.channel.ChannelHandlerContext import io.netty.channel.SimpleChannelInboundHandler -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame -import io.netty.handler.codec.http.websocketx.WebSocketFrame +import io.netty.handler.codec.http.websocketx.* import ipc.ClientOut -import io.netty.handler.codec.http.websocketx.PongWebSocketFrame -import io.netty.channel.Channel final private class FrameHandler(connector: ActorChannelConnector)(using Executor) extends SimpleChannelInboundHandler[WebSocketFrame]: @@ -26,7 +24,11 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo val txt = frame.text if txt.nonEmpty then val endpoint = ctx.channel.attr(key.endpoint).get - if endpoint == null || endpoint.rateLimit(txt) then + if endpoint != null && !endpoint.rateLimit(txt) then + ctx.channel + .writeAndFlush(new CloseWebSocketFrame(1008, "rate limit exceeded")) + .addListener(_ => ctx.channel.close()) + else ClientOut parse txt foreach: case ClientOut.Switch(uri) if endpoint != null =>