Skip to content

Commit

Permalink
Fix handling of WS Close frame in Netty backends (#3826)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski authored Jun 10, 2024
1 parent 9c8be81 commit 9cc98ae
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 82 deletions.
2 changes: 1 addition & 1 deletion project/Versions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ object Versions {
val json4s = "4.0.7"
val metrics4Scala = "4.2.9"
val nettyReactiveStreams = "3.0.2"
val ox = "0.1.0"
val ox = "0.2.1"
val reactiveStreams = "1.0.4"
val sprayJson = "1.3.6"
val scalaCheck = "1.18.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](
sttpFrame
}
val stream: Stream[F, NettyWebSocketFrame] =
optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames)
optionallyConcatenateFrames(o.concatenateFragmentedFrames)(
takeUntilCloseFrame(passAlongCloseFrame = o.decodeCloseRequests)(sttpFrames)
)
.map(f =>
o.requests.decode(f) match {
case x: DecodeResult.Value[REQ] => x.v
Expand Down Expand Up @@ -98,10 +100,19 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP](

}

private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] =
private def optionallyConcatenateFrames(doConcatenate: Boolean)(s: Stream[F, WebSocketFrame]): Stream[F, WebSocketFrame] =
if (doConcatenate) {
s.mapAccumulate(None: Accumulator)(accumulateFrameState).collect { case (_, Some(f)) => f }
} else s

private def takeUntilCloseFrame(passAlongCloseFrame: Boolean)(s: Stream[F, WebSocketFrame]): Stream[F, WebSocketFrame] =
s.takeWhile(
{
case _: WebSocketFrame.Close => false
case _ => true
},
takeFailure = passAlongCloseFrame
)
}

/** A special wrapper used to override internal logic of fs2, which calls cancel() silently when internal stream failures happen, causing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import sttp.tapir.server.netty.NettyResponseContent.{
ReactivePublisherNettyResponseContent,
ReactiveWebSocketProcessorNettyResponseContent
}
import sttp.tapir.server.netty.internal.ws.{NettyControlFrameHandler, WebSocketAutoPingHandler}
import sttp.tapir.server.netty.internal.ws.{WebSocketAutoPingHandler, WebSocketPingPongFrameHandler}
import sttp.tapir.server.netty.{NettyConfig, NettyResponse, NettyServerRequest, Route}

import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -281,10 +281,9 @@ class NettyServerHandler[F[_]](
.addAfter(
ServerCodecHandlerName,
WebSocketControlFrameHandlerName,
new NettyControlFrameHandler(
new WebSocketPingPongFrameHandler(
ignorePong = r.ignorePong,
autoPongOnPing = r.autoPongOnPing,
decodeCloseRequests = r.decodeCloseRequests
autoPongOnPing = r.autoPongOnPing
)
)
r.autoPing.foreach { case (interval, pingMsg) =>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package sttp.tapir.server.netty.internal.ws

import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}
import io.netty.handler.codec.http.websocketx.{CloseWebSocketFrame, PingWebSocketFrame, PongWebSocketFrame}
import io.netty.handler.codec.http.websocketx.{PingWebSocketFrame, PongWebSocketFrame}
import sttp.tapir.server.netty.internal._

/** Handles Ping, Pong, and Close frames for WebSockets.
/** Handles incoming Ping and Pong frames for WebSockets.
*/
class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, decodeCloseRequests: Boolean)
extends ChannelInboundHandlerAdapter {
class WebSocketPingPongFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean) extends ChannelInboundHandlerAdapter {

override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {
msg match {
Expand All @@ -23,16 +22,6 @@ class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, dec
} else {
val _ = pong.content().release()
}
case close: CloseWebSocketFrame =>
if (decodeCloseRequests) {
// Passing the Close frame for further processing
val _ = ctx.fireChannelRead(close)
} else {
// Responding with Close immediately
val _ = ctx
.writeAndFlush(close)
.close()
}
case other =>
val _ = ctx.fireChannelRead(other)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ private[sync] object OxSourceWebSocketProcessor:
val frame2FramePipe: OxStreams.Pipe[NettyWebSocketFrame, NettyWebSocketFrame] =
(source: Source[NettyWebSocketFrame]) => {
pipe(
optionallyConcatenateFrames(
source
.mapAsView { f =>
val sttpFrame = nettyFrameToFrame(f)
f.release()
sttpFrame
},
o.concatenateFragmentedFrames
optionallyConcatenateFrames(o.concatenateFragmentedFrames)(
takeUntilCloseFrame(passAlongCloseFrame = o.decodeCloseRequests)(
source
.mapAsView { f =>
val sttpFrame = nettyFrameToFrame(f)
f.release()
sttpFrame
}
)
)
.mapAsView(f =>
o.requests.decode(f) match {
Expand All @@ -59,12 +60,20 @@ private[sync] object OxSourceWebSocketProcessor:
val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, "Internal Server Error"))
sub.onError(t)
override def onComplete(): Unit =
val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "Bye"))
val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "normal closure"))
sub.onComplete()
}
new OxProcessor(oxDispatcher, frame2FramePipe, wrapSubscriberWithNettyCallback)

private def optionallyConcatenateFrames(s: Source[WebSocketFrame], doConcatenate: Boolean)(using Ox): Source[WebSocketFrame] =
if doConcatenate then
s.mapStateful(() => None: Accumulator)(accumulateFrameState).collectAsView { case Some(f: WebSocketFrame) => f }
private def optionallyConcatenateFrames(doConcatenate: Boolean)(s: Source[WebSocketFrame])(using Ox): Source[WebSocketFrame] =
if doConcatenate then s.mapStateful(() => None: Accumulator)(accumulateFrameState).collectAsView { case Some(f: WebSocketFrame) => f }
else s

private def takeUntilCloseFrame(passAlongCloseFrame: Boolean)(s: Source[WebSocketFrame])(using Ox): Source[WebSocketFrame] =
s.takeWhile({
case _: WebSocketFrame.Close => false
case _ => true
},
includeFailed = passAlongCloseFrame
)

Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,13 @@ class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {
val released: CompletableFuture[Boolean] = new CompletableFuture[Boolean]()
testServer(
endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain].apply(streams)),
"closes supervision scope when client closes Web Socket without getting any responses"
"closes supervision scope when client closes Web Socket"
)((_: Unit) =>
val pipe: OxStreams.Pipe[String, String] = in => {
val outgoing = Channel.bufferedDefault[String]
releaseAfterScope {
released.complete(true).discard
}
outgoing
in
}
Right(pipe)
) { (backend, baseUri) =>
Expand All @@ -67,6 +66,7 @@ class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {
for {
_ <- ws.sendText("test1")
_ <- ws.close()
_ <- ws.receiveText()
closeResponse <- ws.eitherClose(ws.receiveText())
} yield closeResponse
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,52 +315,55 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE](
)
else List.empty

val frameConcatenationTests = if (frameConcatenation) List(
testServer(
endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(false)
.concatenateFragmentedFrames(true)
),
"concatenate fragmented text frames"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.Text("f1", finalFragment = false, None))
_ <- ws.sendText("f2")
r <- ws.receiveText()
_ <- ws.close()
} yield r
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { _.body shouldBe (Right("echo: f1f2")) }
},
testServer(
endpoint.out(
webSocketBody[Array[Byte], CodecFormat.OctetStream, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(false)
.concatenateFragmentedFrames(true)
),
"concatenate fragmented binary frames"
)((_: Unit) => pureResult(functionToPipe((bs: Array[Byte]) => s"echo: ${new String(bs)}").asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.Binary("frame1-bytes;".getBytes(), finalFragment = false, None))
_ <- ws.sendBinary("frame2-bytes".getBytes())
r <- ws.receiveText()
_ <- ws.close()
} yield r
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { _.body shouldBe (Right("echo: frame1-bytes;frame2-bytes")) }
}
) else Nil
val frameConcatenationTests =
if (frameConcatenation)
List(
testServer(
endpoint.out(
webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(false)
.concatenateFragmentedFrames(true)
),
"concatenate fragmented text frames"
)((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.Text("f1", finalFragment = false, None))
_ <- ws.sendText("f2")
r <- ws.receiveText()
_ <- ws.close()
} yield r
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { _.body shouldBe (Right("echo: f1f2")) }
},
testServer(
endpoint.out(
webSocketBody[Array[Byte], CodecFormat.OctetStream, String, CodecFormat.TextPlain](streams)
.autoPing(None)
.autoPongOnPing(false)
.concatenateFragmentedFrames(true)
),
"concatenate fragmented binary frames"
)((_: Unit) => pureResult(functionToPipe((bs: Array[Byte]) => s"echo: ${new String(bs)}").asRight[Unit])) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.send(WebSocketFrame.Binary("frame1-bytes;".getBytes(), finalFragment = false, None))
_ <- ws.sendBinary("frame2-bytes".getBytes())
r <- ws.receiveText()
_ <- ws.close()
} yield r
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { _.body shouldBe (Right("echo: frame1-bytes;frame2-bytes")) }
}
)
else Nil

val handlePongTests =
if (handlePong)
Expand Down

0 comments on commit 9cc98ae

Please sign in to comment.