From 585434d022901aade1b60b3a98372303e83ea6c6 Mon Sep 17 00:00:00 2001 From: Krzysztof Ciesielski Date: Thu, 9 May 2024 09:20:46 +0200 Subject: [PATCH] [bugfix] Ensure cancelation of WS pipeline forks (#3755) --- .../netty/sync/internal/ox/OxDispatcher.scala | 21 ++++++--- .../reactivestreams/OxProcessor.scala | 36 +++++++++++++-- .../netty/sync/NettySyncServerTest.scala | 46 +++++++++++++++++-- 3 files changed, 87 insertions(+), 16 deletions(-) diff --git a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ox/OxDispatcher.scala b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ox/OxDispatcher.scala index bc03173f76..9a9ad8207b 100644 --- a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ox/OxDispatcher.scala +++ b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/ox/OxDispatcher.scala @@ -3,6 +3,10 @@ package sttp.tapir.server.netty.sync.internal.ox import ox.* import ox.channels.Actor +import scala.util.control.NonFatal +import scala.concurrent.Future +import scala.concurrent.Promise + /** A dispatcher that can start arbitrary forks. Useful when one needs to start an asynchronous task from a thread outside of an Ox scope. * Normally Ox doesn't allow to start forks from other threads, for example in callbacks of external libraries. If you create an * OxDispatcher inside a scope and pass it for potential handling on another thread, that thread can call @@ -20,12 +24,17 @@ import ox.channels.Actor */ private[sync] class OxDispatcher()(using ox: Ox): private class Runner: - def runAsync(thunk: Ox ?=> Unit, onError: Throwable => Unit): Unit = - fork { - try supervised(thunk) - catch case e => onError(e) - }.discard + def runAsync(thunk: Ox ?=> Unit, onError: Throwable => Unit, forkPromise: Promise[CancellableFork[Unit]]): Unit = + forkPromise + .success(forkCancellable { + try supervised(thunk) + catch case NonFatal(e) => onError(e) + }) + .discard private val actor = Actor.create(new Runner) - def runAsync(thunk: Ox ?=> Unit)(onError: Throwable => Unit): Unit = actor.tell(_.runAsync(thunk, onError)) + def runAsync(thunk: Ox ?=> Unit)(onError: Throwable => Unit): Future[CancellableFork[Unit]] = + val forkPromise = Promise[CancellableFork[Unit]]() + actor.tell(_.runAsync(thunk, onError, forkPromise)) + forkPromise.future diff --git a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/reactivestreams/OxProcessor.scala b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/reactivestreams/OxProcessor.scala index b2e928c561..35c7d3ffc6 100644 --- a/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/reactivestreams/OxProcessor.scala +++ b/server/netty-server/sync/src/main/scala/sttp/tapir/server/netty/sync/internal/reactivestreams/OxProcessor.scala @@ -1,12 +1,15 @@ package sttp.tapir.server.netty.sync.internal.reactivestreams -import org.reactivestreams.Subscriber +import org.reactivestreams.{Processor, Subscriber, Subscription} +import org.slf4j.LoggerFactory import ox.* import ox.channels.* -import org.reactivestreams.Subscription -import org.reactivestreams.Processor -import sttp.tapir.server.netty.sync.internal.ox.OxDispatcher import sttp.tapir.server.netty.sync.OxStreams +import sttp.tapir.server.netty.sync.internal.ox.OxDispatcher + +import scala.concurrent.duration.* +import scala.concurrent.{Await, Future} +import scala.util.control.NonFatal /** A reactive Processor, which is both a Publisher and a Subscriber * @@ -23,15 +26,20 @@ private[sync] class OxProcessor[A, B]( pipeline: OxStreams.Pipe[A, B], wrapSubscriber: Subscriber[? >: B] => Subscriber[? >: B] ) extends Processor[A, B]: + private val logger = LoggerFactory.getLogger(getClass.getName) // Incoming requests are read from this subscription into an Ox Channel[A] @volatile private var requestsSubscription: Subscription = _ // An internal channel for holding incoming requests (`A`), will be wrapped with user's pipeline to produce responses (`B`) private val channel = Channel.buffered[A](1) + private val pipelineCancelationTimeout = 5.seconds + @volatile private var pipelineForkFuture: Future[CancellableFork[Unit]] = _ + override def onError(reason: Throwable): Unit = // As per rule 2.13, we need to throw a `java.lang.NullPointerException` if the `Throwable` is `null` if reason == null then throw null channel.errorOrClosed(reason).discard + cancelPipelineFork() override def onNext(a: A): Unit = if a == null then throw new NullPointerException("Element cannot be null") // Rule 2.13 @@ -51,11 +59,12 @@ private[sync] class OxProcessor[A, B]( override def onComplete(): Unit = channel.doneOrClosed().discard + cancelPipelineFork() override def subscribe(subscriber: Subscriber[? >: B]): Unit = if subscriber == null then throw new NullPointerException("Subscriber cannot be null") val wrappedSubscriber = wrapSubscriber(subscriber) - oxDispatcher.runAsync { + pipelineForkFuture = oxDispatcher.runAsync { val outgoingResponses: Source[B] = pipeline((channel: Source[A]).mapAsView { e => requestsSubscription.request(1) e @@ -68,6 +77,23 @@ private[sync] class OxProcessor[A, B]( onError(error) } + private def cancelPipelineFork(): Unit = + if (pipelineForkFuture != null) try { + val pipelineFork = Await.result(pipelineForkFuture, pipelineCancelationTimeout) + oxDispatcher.runAsync { + race( + { + ox.sleep(pipelineCancelationTimeout) + logger.error(s"Pipeline fork cancelation did not complete in time ($pipelineCancelationTimeout).") + }, + pipelineFork.cancel() + ) match { + case Left(NonFatal(e)) => logger.error("Error when canceling pipeline fork", e) + case _ => () + } + } { e => logger.error("Error when canceling pipeline fork", e) }.discard + } catch case NonFatal(e) => logger.error("Error when waiting for pipeline fork to start", e) + private def cancelSubscription() = if requestsSubscription != null then try requestsSubscription.cancel() diff --git a/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala b/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala index c7bd39db8c..b37d73ac1f 100644 --- a/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala +++ b/server/netty-server/sync/src/test/scala/sttp/tapir/server/netty/sync/NettySyncServerTest.scala @@ -1,27 +1,30 @@ package sttp.tapir.server.netty.sync import cats.data.NonEmptyList -import cats.effect.unsafe.implicits.global import cats.effect.IO +import cats.effect.unsafe.implicits.global import io.netty.channel.nio.NioEventLoopGroup import org.scalactic.source.Position +import org.scalatest.BeforeAndAfterAll import org.scalatest.compatible.Assertion import org.scalatest.funsuite.AsyncFunSuite -import org.scalatest.BeforeAndAfterAll +import org.scalatest.matchers.should.Matchers.* import org.slf4j.LoggerFactory import ox.* -import ox.channels.Source +import ox.channels.* import sttp.capabilities.WebSockets import sttp.capabilities.fs2.Fs2Streams import sttp.client3.* import sttp.model.* -import sttp.tapir.PublicEndpoint +import sttp.tapir.* import sttp.tapir.server.ServerEndpoint import sttp.tapir.server.tests.* import sttp.tapir.tests.* +import sttp.ws.{WebSocket, WebSocketFrame} -import scala.concurrent.duration.FiniteDuration +import java.util.concurrent.{CompletableFuture, TimeUnit} import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll { @@ -41,6 +44,39 @@ class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll { new ServerWebSocketTests(createServerTest, OxStreams, autoPing = true, failingPipe = true, handlePong = true) { override def functionToPipe[A, B](f: A => B): OxStreams.Pipe[A, B] = ox ?=> in => in.map(f) override def emptyPipe[A, B]: OxStreams.Pipe[A, B] = _ => Source.empty + + import createServerTest._ + override def tests(): List[Test] = super.tests() ++ List({ + val released: CompletableFuture[Boolean] = new CompletableFuture[Boolean]() + testServer( + endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain].apply(streams)), + "closes supervision scope when client closes Web Socket without getting any responses" + )((_: Unit) => + val pipe: OxStreams.Pipe[String, String] = in => { + val outgoing = Channel.bufferedDefault[String] + releaseAfterScope { + released.complete(true).discard + } + outgoing + } + Right(pipe) + ) { (backend, baseUri) => + basicRequest + .response(asWebSocket { (ws: WebSocket[IO]) => + for { + _ <- ws.sendText("test1") + _ <- ws.close() + closeResponse <- ws.eitherClose(ws.receiveText()) + } yield closeResponse + }) + .get(baseUri.scheme("ws")) + .send(backend) + .map { r => + r.body.value shouldBe Left(WebSocketFrame.Close(1000, "normal closure")) + released.get(15, TimeUnit.SECONDS) shouldBe true + } + } + }) }.tests() tests.foreach { t =>