From c832d40e7c28e79db39aac576e077cb2c89ccb2f Mon Sep 17 00:00:00 2001 From: Jonathan Gamble Date: Sat, 30 Dec 2023 12:24:22 -0600 Subject: [PATCH 1/2] close connections when rate limits are exceeded --- src/main/scala/netty/FrameHandler.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/main/scala/netty/FrameHandler.scala b/src/main/scala/netty/FrameHandler.scala index e0587393..9cad5fde 100644 --- a/src/main/scala/netty/FrameHandler.scala +++ b/src/main/scala/netty/FrameHandler.scala @@ -2,14 +2,12 @@ package lila.ws package netty import com.typesafe.scalalogging.Logger +import io.netty.channel.Channel import io.netty.channel.ChannelHandlerContext import io.netty.channel.SimpleChannelInboundHandler -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame -import io.netty.handler.codec.http.websocketx.WebSocketFrame +import io.netty.handler.codec.http.websocketx.* import ipc.ClientOut -import io.netty.handler.codec.http.websocketx.PongWebSocketFrame -import io.netty.channel.Channel final private class FrameHandler(connector: ActorChannelConnector)(using Executor) extends SimpleChannelInboundHandler[WebSocketFrame]: @@ -26,7 +24,11 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo val txt = frame.text if txt.nonEmpty then val endpoint = ctx.channel.attr(key.endpoint).get - if endpoint == null || endpoint.rateLimit(txt) then + if endpoint != null && !endpoint.rateLimit(txt) then + ctx.channel + .writeAndFlush(new CloseWebSocketFrame(1008, "rate limit exceeded")) + .addListener(_ => ctx.channel.close()) + else ClientOut parse txt foreach: case ClientOut.Switch(uri) if endpoint != null => From 61cfe67db7e8d3a1509ffe68b0970ef9a42cf382 Mon Sep 17 00:00:00 2001 From: Jonathan Gamble Date: Sat, 30 Dec 2023 13:03:40 -0600 Subject: [PATCH 2/2] remove new --- src/main/scala/netty/FrameHandler.scala | 16 ++++++++++++---- src/main/scala/util/RateLimit.scala | 3 ++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/main/scala/netty/FrameHandler.scala b/src/main/scala/netty/FrameHandler.scala index 9cad5fde..a178e5bf 100644 --- a/src/main/scala/netty/FrameHandler.scala +++ b/src/main/scala/netty/FrameHandler.scala @@ -25,9 +25,7 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo if txt.nonEmpty then val endpoint = ctx.channel.attr(key.endpoint).get if endpoint != null && !endpoint.rateLimit(txt) then - ctx.channel - .writeAndFlush(new CloseWebSocketFrame(1008, "rate limit exceeded")) - .addListener(_ => ctx.channel.close()) + shutdown(ctx.channel, 1008, "rate limit exceeded") else ClientOut parse txt foreach: @@ -37,17 +35,22 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo case ClientOut.Unexpected(msg) => Monitor.clientOutUnexpected.increment() logger.info(s"Unexpected $msg") + shutdown(ctx.channel, 1011, "unexpected message") case ClientOut.WrongHole => Monitor.clientOutWrongHole.increment() + // choice of hole is a personal decision case out => withClientOf(ctx.channel)(_ tell out) + case frame: PongWebSocketFrame => val lagMillis = (System.currentTimeMillis() - frame.content().getLong(0)).toInt val pong = ClientOut.RoundPongFrame(lagMillis) Option(ctx.channel.attr(key.client).get) foreach { _.value foreach { _ foreach (_ ! pong) } } + case frame => logger.info("unsupported frame type: " + frame.getClass.getName) + shutdown(ctx.channel, 1003, s"${frame.getClass.getName} unsupported") private def withClientOf(channel: Channel)(f: Client => Unit): Unit = Option(channel.attr(key.client).get) match @@ -55,7 +58,12 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo clientFu.value match case Some(client) => client foreach f case None => clientFu foreach f - case None => logger.warn(s"No client actor on channel $channel") + case None => + logger.warn(s"No client actor on channel $channel") + shutdown(channel, 1011, s"no actor on $channel") + + private def shutdown(channel: Channel, code: Int, reason: String): Unit = + channel.writeAndFlush(CloseWebSocketFrame(code, reason)).addListener(_ => channel.close()) private object FrameHandler: diff --git a/src/main/scala/util/RateLimit.scala b/src/main/scala/util/RateLimit.scala index c1fe5765..fc449c0a 100644 --- a/src/main/scala/util/RateLimit.scala +++ b/src/main/scala/util/RateLimit.scala @@ -16,7 +16,8 @@ final class RateLimit( private var logged: Boolean = false def apply(msg: => String = ""): Boolean = - if credits > 0 then + if msg == "null" then true + else if credits > 0 then credits -= 1 true else if clearAt < nowMillis then