Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor frame handler improvements #514

Merged
merged 2 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions src/main/scala/netty/FrameHandler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@ package lila.ws
package netty

import com.typesafe.scalalogging.Logger
import io.netty.channel.Channel
import io.netty.channel.ChannelHandlerContext
import io.netty.channel.SimpleChannelInboundHandler
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame
import io.netty.handler.codec.http.websocketx.WebSocketFrame
import io.netty.handler.codec.http.websocketx.*

import ipc.ClientOut
import io.netty.handler.codec.http.websocketx.PongWebSocketFrame
import io.netty.channel.Channel

final private class FrameHandler(connector: ActorChannelConnector)(using Executor)
extends SimpleChannelInboundHandler[WebSocketFrame]:
Expand All @@ -26,7 +24,9 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo
val txt = frame.text
if txt.nonEmpty then
val endpoint = ctx.channel.attr(key.endpoint).get
if endpoint == null || endpoint.rateLimit(txt) then
if endpoint != null && !endpoint.rateLimit(txt) then
shutdown(ctx.channel, 1008, "rate limit exceeded")
else
ClientOut parse txt foreach:

case ClientOut.Switch(uri) if endpoint != null =>
Expand All @@ -35,25 +35,35 @@ final private class FrameHandler(connector: ActorChannelConnector)(using Executo
case ClientOut.Unexpected(msg) =>
Monitor.clientOutUnexpected.increment()
logger.info(s"Unexpected $msg")
shutdown(ctx.channel, 1011, "unexpected message")

case ClientOut.WrongHole =>
Monitor.clientOutWrongHole.increment()
// choice of hole is a personal decision

case out => withClientOf(ctx.channel)(_ tell out)

case frame: PongWebSocketFrame =>
val lagMillis = (System.currentTimeMillis() - frame.content().getLong(0)).toInt
val pong = ClientOut.RoundPongFrame(lagMillis)
Option(ctx.channel.attr(key.client).get) foreach { _.value foreach { _ foreach (_ ! pong) } }

case frame =>
logger.info("unsupported frame type: " + frame.getClass.getName)
shutdown(ctx.channel, 1003, s"${frame.getClass.getName} unsupported")

private def withClientOf(channel: Channel)(f: Client => Unit): Unit =
Option(channel.attr(key.client).get) match
case Some(clientFu) =>
clientFu.value match
case Some(client) => client foreach f
case None => clientFu foreach f
case None => logger.warn(s"No client actor on channel $channel")
case None =>
logger.warn(s"No client actor on channel $channel")
shutdown(channel, 1011, s"no actor on $channel")

private def shutdown(channel: Channel, code: Int, reason: String): Unit =
channel.writeAndFlush(CloseWebSocketFrame(code, reason)).addListener(_ => channel.close())

private object FrameHandler:

Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/util/RateLimit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ final class RateLimit(
private var logged: Boolean = false

def apply(msg: => String = ""): Boolean =
if credits > 0 then
if msg == "null" then true
else if credits > 0 then
credits -= 1
true
else if clearAt < nowMillis then
Expand Down