From 9cc98ae6e19ba84b808a5ac39c2b8e92e762b6d6 Mon Sep 17 00:00:00 2001 From: Krzysztof Ciesielski Date: Mon, 10 Jun 2024 10:15:05 +0200 Subject: [PATCH] Fix handling of WS Close frame in Netty backends (#3826) --- project/Versions.scala | 2 +- .../internal/WebSocketPipeProcessor.scala | 15 ++- .../netty/internal/NettyServerHandler.scala | 7 +- ...la => WebSocketPingPongFrameHandler.scala} | 17 +--- .../ws/OxSourceWebSocketProcessor.scala | 33 ++++--- .../netty/sync/NettySyncServerTest.scala | 6 +- .../server/tests/ServerWebSocketTests.scala | 95 ++++++++++--------- 7 files changed, 93 insertions(+), 82 deletions(-) rename server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/{WebSocketControlFrameHandler.scala => WebSocketPingPongFrameHandler.scala} (53%) diff --git a/project/Versions.scala b/project/Versions.scala index f5a0683100..8ec5efcd63 100644 --- a/project/Versions.scala +++ b/project/Versions.scala @@ -25,7 +25,7 @@ object Versions { val json4s = "4.0.7" val metrics4Scala = "4.2.9" val nettyReactiveStreams = "3.0.2" - val ox = "0.1.0" + val ox = "0.2.1" val reactiveStreams = "1.0.4" val sprayJson = "1.3.6" val scalaCheck = "1.18.0" diff --git a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala index 3612925382..f248462663 100644 --- a/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala +++ b/server/netty-server/cats/src/main/scala/sttp/tapir/server/netty/cats/internal/WebSocketPipeProcessor.scala @@ -45,7 +45,9 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( sttpFrame } val stream: Stream[F, NettyWebSocketFrame] = - optionallyConcatenateFrames(sttpFrames, o.concatenateFragmentedFrames) + optionallyConcatenateFrames(o.concatenateFragmentedFrames)( + takeUntilCloseFrame(passAlongCloseFrame = o.decodeCloseRequests)(sttpFrames) + ) .map(f => o.requests.decode(f) match { case x: DecodeResult.Value[REQ] => x.v @@ -98,10 +100,19 @@ class WebSocketPipeProcessor[F[_]: Async, REQ, RESP]( } - private def optionallyConcatenateFrames(s: Stream[F, WebSocketFrame], doConcatenate: Boolean): Stream[F, WebSocketFrame] = + private def optionallyConcatenateFrames(doConcatenate: Boolean)(s: Stream[F, WebSocketFrame]): Stream[F, WebSocketFrame] = if (doConcatenate) { s.mapAccumulate(None: Accumulator)(accumulateFrameState).collect { case (_, Some(f)) => f } } else s + + private def takeUntilCloseFrame(passAlongCloseFrame: Boolean)(s: Stream[F, WebSocketFrame]): Stream[F, WebSocketFrame] = + s.takeWhile( + { + case _: WebSocketFrame.Close => false + case _ => true + }, + takeFailure = passAlongCloseFrame + ) } /** A special wrapper used to override internal logic of fs2, which calls cancel() silently when internal stream failures happen, causing diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala index a0d49f546b..6183072e0f 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/NettyServerHandler.scala @@ -20,7 +20,7 @@ import sttp.tapir.server.netty.NettyResponseContent.{ ReactivePublisherNettyResponseContent, ReactiveWebSocketProcessorNettyResponseContent } -import sttp.tapir.server.netty.internal.ws.{NettyControlFrameHandler, WebSocketAutoPingHandler} +import sttp.tapir.server.netty.internal.ws.{WebSocketAutoPingHandler, WebSocketPingPongFrameHandler} import sttp.tapir.server.netty.{NettyConfig, NettyResponse, NettyServerRequest, Route} import java.util.concurrent.TimeUnit @@ -281,10 +281,9 @@ class NettyServerHandler[F[_]]( .addAfter( ServerCodecHandlerName, WebSocketControlFrameHandlerName, - new NettyControlFrameHandler( + new WebSocketPingPongFrameHandler( ignorePong = r.ignorePong, - autoPongOnPing = r.autoPongOnPing, - decodeCloseRequests = r.decodeCloseRequests + autoPongOnPing = r.autoPongOnPing ) ) r.autoPing.foreach { case (interval, pingMsg) => diff --git a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketControlFrameHandler.scala b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketPingPongFrameHandler.scala similarity index 53% rename from server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketControlFrameHandler.scala rename to server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketPingPongFrameHandler.scala index 04cfb931be..1bdcd41ecd 100644 --- a/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketControlFrameHandler.scala +++ b/server/netty-server/src/main/scala/sttp/tapir/server/netty/internal/ws/WebSocketPingPongFrameHandler.scala @@ -1,13 +1,12 @@ package sttp.tapir.server.netty.internal.ws import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} -import io.netty.handler.codec.http.websocketx.{CloseWebSocketFrame, PingWebSocketFrame, PongWebSocketFrame} +import io.netty.handler.codec.http.websocketx.{PingWebSocketFrame, PongWebSocketFrame} import sttp.tapir.server.netty.internal._ -/** Handles Ping, Pong, and Close frames for WebSockets. +/** Handles incoming Ping and Pong frames for WebSockets. */ -class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, decodeCloseRequests: Boolean) - extends ChannelInboundHandlerAdapter { +class WebSocketPingPongFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean) extends ChannelInboundHandlerAdapter { override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = { msg match { @@ -23,16 +22,6 @@ class NettyControlFrameHandler(ignorePong: Boolean, autoPongOnPing: Boolean, dec } else { val _ = pong.content().release() } - case close: CloseWebSocketFrame => - if (decodeCloseRequests) { - // Passing the Close frame for further processing - val _ = ctx.fireChannelRead(close) - } else { - // Responding with Close immediately - val _ = ctx - .writeAndFlush(close) - .close() - } case other => val _ = ctx.fireChannelRead(other) } diff --git a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ws/OxSourceWebSocketProcessor.scala b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ws/OxSourceWebSocketProcessor.scala index fb7ecc5d18..c3e4767f2a 100644 --- a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ws/OxSourceWebSocketProcessor.scala +++ b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ws/OxSourceWebSocketProcessor.scala @@ -26,14 +26,15 @@ private[sync] object OxSourceWebSocketProcessor: val frame2FramePipe: OxStreams.Pipe[NettyWebSocketFrame, NettyWebSocketFrame] = (source: Source[NettyWebSocketFrame]) => { pipe( - optionallyConcatenateFrames( - source - .mapAsView { f => - val sttpFrame = nettyFrameToFrame(f) - f.release() - sttpFrame - }, - o.concatenateFragmentedFrames + optionallyConcatenateFrames(o.concatenateFragmentedFrames)( + takeUntilCloseFrame(passAlongCloseFrame = o.decodeCloseRequests)( + source + .mapAsView { f => + val sttpFrame = nettyFrameToFrame(f) + f.release() + sttpFrame + } + ) ) .mapAsView(f => o.requests.decode(f) match { @@ -59,12 +60,20 @@ private[sync] object OxSourceWebSocketProcessor: val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.INTERNAL_SERVER_ERROR, "Internal Server Error")) sub.onError(t) override def onComplete(): Unit = - val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "Bye")) + val _ = ctx.writeAndFlush(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "normal closure")) sub.onComplete() } new OxProcessor(oxDispatcher, frame2FramePipe, wrapSubscriberWithNettyCallback) - private def optionallyConcatenateFrames(s: Source[WebSocketFrame], doConcatenate: Boolean)(using Ox): Source[WebSocketFrame] = - if doConcatenate then - s.mapStateful(() => None: Accumulator)(accumulateFrameState).collectAsView { case Some(f: WebSocketFrame) => f } + private def optionallyConcatenateFrames(doConcatenate: Boolean)(s: Source[WebSocketFrame])(using Ox): Source[WebSocketFrame] = + if doConcatenate then s.mapStateful(() => None: Accumulator)(accumulateFrameState).collectAsView { case Some(f: WebSocketFrame) => f } else s + + private def takeUntilCloseFrame(passAlongCloseFrame: Boolean)(s: Source[WebSocketFrame])(using Ox): Source[WebSocketFrame] = + s.takeWhile({ + case _: WebSocketFrame.Close => false + case _ => true + }, + includeFailed = passAlongCloseFrame + ) + diff --git a/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala b/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala index d0301fabbb..d5aae0d4ad 100644 --- a/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala +++ b/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala @@ -51,14 +51,13 @@ class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll { val released: CompletableFuture[Boolean] = new CompletableFuture[Boolean]() testServer( endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain].apply(streams)), - "closes supervision scope when client closes Web Socket without getting any responses" + "closes supervision scope when client closes Web Socket" )((_: Unit) => val pipe: OxStreams.Pipe[String, String] = in => { - val outgoing = Channel.bufferedDefault[String] releaseAfterScope { released.complete(true).discard } - outgoing + in } Right(pipe) ) { (backend, baseUri) => @@ -67,6 +66,7 @@ class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll { for { _ <- ws.sendText("test1") _ <- ws.close() + _ <- ws.receiveText() closeResponse <- ws.eitherClose(ws.receiveText()) } yield closeResponse }) diff --git a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala index b56ca42041..88abed0a82 100644 --- a/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala +++ b/server/tests/src/main/scala/sttp/tapir/server/tests/ServerWebSocketTests.scala @@ -315,52 +315,55 @@ abstract class ServerWebSocketTests[F[_], S <: Streams[S], OPTIONS, ROUTE]( ) else List.empty - val frameConcatenationTests = if (frameConcatenation) List( - testServer( - endpoint.out( - webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) - .autoPing(None) - .autoPongOnPing(false) - .concatenateFragmentedFrames(true) - ), - "concatenate fragmented text frames" - )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => - basicRequest - .response(asWebSocket { (ws: WebSocket[IO]) => - for { - _ <- ws.send(WebSocketFrame.Text("f1", finalFragment = false, None)) - _ <- ws.sendText("f2") - r <- ws.receiveText() - _ <- ws.close() - } yield r - }) - .get(baseUri.scheme("ws")) - .send(backend) - .map { _.body shouldBe (Right("echo: f1f2")) } - }, - testServer( - endpoint.out( - webSocketBody[Array[Byte], CodecFormat.OctetStream, String, CodecFormat.TextPlain](streams) - .autoPing(None) - .autoPongOnPing(false) - .concatenateFragmentedFrames(true) - ), - "concatenate fragmented binary frames" - )((_: Unit) => pureResult(functionToPipe((bs: Array[Byte]) => s"echo: ${new String(bs)}").asRight[Unit])) { (backend, baseUri) => - basicRequest - .response(asWebSocket { (ws: WebSocket[IO]) => - for { - _ <- ws.send(WebSocketFrame.Binary("frame1-bytes;".getBytes(), finalFragment = false, None)) - _ <- ws.sendBinary("frame2-bytes".getBytes()) - r <- ws.receiveText() - _ <- ws.close() - } yield r - }) - .get(baseUri.scheme("ws")) - .send(backend) - .map { _.body shouldBe (Right("echo: frame1-bytes;frame2-bytes")) } - } - ) else Nil + val frameConcatenationTests = + if (frameConcatenation) + List( + testServer( + endpoint.out( + webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .autoPongOnPing(false) + .concatenateFragmentedFrames(true) + ), + "concatenate fragmented text frames" + )((_: Unit) => pureResult(stringEcho.asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.send(WebSocketFrame.Text("f1", finalFragment = false, None)) + _ <- ws.sendText("f2") + r <- ws.receiveText() + _ <- ws.close() + } yield r + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map { _.body shouldBe (Right("echo: f1f2")) } + }, + testServer( + endpoint.out( + webSocketBody[Array[Byte], CodecFormat.OctetStream, String, CodecFormat.TextPlain](streams) + .autoPing(None) + .autoPongOnPing(false) + .concatenateFragmentedFrames(true) + ), + "concatenate fragmented binary frames" + )((_: Unit) => pureResult(functionToPipe((bs: Array[Byte]) => s"echo: ${new String(bs)}").asRight[Unit])) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.send(WebSocketFrame.Binary("frame1-bytes;".getBytes(), finalFragment = false, None)) + _ <- ws.sendBinary("frame2-bytes".getBytes()) + r <- ws.receiveText() + _ <- ws.close() + } yield r + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map { _.body shouldBe (Right("echo: frame1-bytes;frame2-bytes")) } + } + ) + else Nil val handlePongTests = if (handlePong)