diff --git a/src/main/scala/netty/FrameHandler.scala b/src/main/scala/netty/FrameHandler.scala index e0587393..a178e5bf 100644 --- a/src/main/scala/netty/FrameHandler.scala +++ b/src/main/scala/netty/FrameHandler.scala @@ -2,14 +2,12 @@ package lila.ws package netty import com.typesafe.scalalogging.Logger +import io.netty.channel.Channel import io.netty.channel.ChannelHandlerContext import io.netty.channel.SimpleChannelInboundHandler -import io.netty.handler.codec.http.websocketx.TextWebSocketFrame -import io.netty.handler.codec.http.websocketx.WebSocketFrame +import io.netty.handler.codec.http.websocketx.* import ipc.ClientOut -import io.netty.handler.codec.http.websocketx.PongWebSocketFrame -import io.netty.channel.Channel final private class FrameHandler(connector: ActorChannelConnector)(using Executor) extends SimpleChannelInboundHandler[WebSocketFrame]: @@ -26,7 +24,9 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo val txt = frame.text if txt.nonEmpty then val endpoint = ctx.channel.attr(key.endpoint).get - if endpoint == null || endpoint.rateLimit(txt) then + if endpoint != null && !endpoint.rateLimit(txt) then + shutdown(ctx.channel, 1008, "rate limit exceeded") + else ClientOut parse txt foreach: case ClientOut.Switch(uri) if endpoint != null => @@ -35,17 +35,22 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo case ClientOut.Unexpected(msg) => Monitor.clientOutUnexpected.increment() logger.info(s"Unexpected $msg") + shutdown(ctx.channel, 1011, "unexpected message") case ClientOut.WrongHole => Monitor.clientOutWrongHole.increment() + // choice of hole is a personal decision case out => withClientOf(ctx.channel)(_ tell out) + case frame: PongWebSocketFrame => val lagMillis = (System.currentTimeMillis() - frame.content().getLong(0)).toInt val pong = ClientOut.RoundPongFrame(lagMillis) Option(ctx.channel.attr(key.client).get) foreach { _.value foreach { _ foreach (_ ! pong) } } + case frame => logger.info("unsupported frame type: " + frame.getClass.getName) + shutdown(ctx.channel, 1003, s"${frame.getClass.getName} unsupported") private def withClientOf(channel: Channel)(f: Client => Unit): Unit = Option(channel.attr(key.client).get) match @@ -53,7 +58,12 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo clientFu.value match case Some(client) => client foreach f case None => clientFu foreach f - case None => logger.warn(s"No client actor on channel $channel") + case None => + logger.warn(s"No client actor on channel $channel") + shutdown(channel, 1011, s"no actor on $channel") + + private def shutdown(channel: Channel, code: Int, reason: String): Unit = + channel.writeAndFlush(CloseWebSocketFrame(code, reason)).addListener(_ => channel.close()) private object FrameHandler: diff --git a/src/main/scala/util/RateLimit.scala b/src/main/scala/util/RateLimit.scala index c1fe5765..fc449c0a 100644 --- a/src/main/scala/util/RateLimit.scala +++ b/src/main/scala/util/RateLimit.scala @@ -16,7 +16,8 @@ final class RateLimit( private var logged: Boolean = false def apply(msg: => String = ""): Boolean = - if credits > 0 then + if msg == "null" then true + else if credits > 0 then credits -= 1 true else if clearAt < nowMillis then