Skip to content

Commit

Permalink
[bugfix] Ensure cancelation of WS pipeline forks (#3755)
Browse files Browse the repository at this point in the history
  • Loading branch information
kciesielski authored May 9, 2024
1 parent 8bcc718 commit 585434d
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ package sttp.tapir.server.netty.sync.internal.ox
import ox.*
import ox.channels.Actor

import scala.util.control.NonFatal
import scala.concurrent.Future
import scala.concurrent.Promise

/** A dispatcher that can start arbitrary forks. Useful when one needs to start an asynchronous task from a thread outside of an Ox scope.
* Normally Ox doesn't allow to start forks from other threads, for example in callbacks of external libraries. If you create an
* OxDispatcher inside a scope and pass it for potential handling on another thread, that thread can call
Expand All @@ -20,12 +24,17 @@ import ox.channels.Actor
*/
private[sync] class OxDispatcher()(using ox: Ox):
private class Runner:
def runAsync(thunk: Ox ?=> Unit, onError: Throwable => Unit): Unit =
fork {
try supervised(thunk)
catch case e => onError(e)
}.discard
def runAsync(thunk: Ox ?=> Unit, onError: Throwable => Unit, forkPromise: Promise[CancellableFork[Unit]]): Unit =
forkPromise
.success(forkCancellable {
try supervised(thunk)
catch case NonFatal(e) => onError(e)
})
.discard

private val actor = Actor.create(new Runner)

def runAsync(thunk: Ox ?=> Unit)(onError: Throwable => Unit): Unit = actor.tell(_.runAsync(thunk, onError))
def runAsync(thunk: Ox ?=> Unit)(onError: Throwable => Unit): Future[CancellableFork[Unit]] =
val forkPromise = Promise[CancellableFork[Unit]]()
actor.tell(_.runAsync(thunk, onError, forkPromise))
forkPromise.future
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package sttp.tapir.server.netty.sync.internal.reactivestreams

import org.reactivestreams.Subscriber
import org.reactivestreams.{Processor, Subscriber, Subscription}
import org.slf4j.LoggerFactory
import ox.*
import ox.channels.*
import org.reactivestreams.Subscription
import org.reactivestreams.Processor
import sttp.tapir.server.netty.sync.internal.ox.OxDispatcher
import sttp.tapir.server.netty.sync.OxStreams
import sttp.tapir.server.netty.sync.internal.ox.OxDispatcher

import scala.concurrent.duration.*
import scala.concurrent.{Await, Future}
import scala.util.control.NonFatal

/** A reactive Processor, which is both a Publisher and a Subscriber
*
Expand All @@ -23,15 +26,20 @@ private[sync] class OxProcessor[A, B](
pipeline: OxStreams.Pipe[A, B],
wrapSubscriber: Subscriber[? >: B] => Subscriber[? >: B]
) extends Processor[A, B]:
private val logger = LoggerFactory.getLogger(getClass.getName)
// Incoming requests are read from this subscription into an Ox Channel[A]
@volatile private var requestsSubscription: Subscription = _
// An internal channel for holding incoming requests (`A`), will be wrapped with user's pipeline to produce responses (`B`)
private val channel = Channel.buffered[A](1)

private val pipelineCancelationTimeout = 5.seconds
@volatile private var pipelineForkFuture: Future[CancellableFork[Unit]] = _

override def onError(reason: Throwable): Unit =
// As per rule 2.13, we need to throw a `java.lang.NullPointerException` if the `Throwable` is `null`
if reason == null then throw null
channel.errorOrClosed(reason).discard
cancelPipelineFork()

override def onNext(a: A): Unit =
if a == null then throw new NullPointerException("Element cannot be null") // Rule 2.13
Expand All @@ -51,11 +59,12 @@ private[sync] class OxProcessor[A, B](

override def onComplete(): Unit =
channel.doneOrClosed().discard
cancelPipelineFork()

override def subscribe(subscriber: Subscriber[? >: B]): Unit =
if subscriber == null then throw new NullPointerException("Subscriber cannot be null")
val wrappedSubscriber = wrapSubscriber(subscriber)
oxDispatcher.runAsync {
pipelineForkFuture = oxDispatcher.runAsync {
val outgoingResponses: Source[B] = pipeline((channel: Source[A]).mapAsView { e =>
requestsSubscription.request(1)
e
Expand All @@ -68,6 +77,23 @@ private[sync] class OxProcessor[A, B](
onError(error)
}

private def cancelPipelineFork(): Unit =
if (pipelineForkFuture != null) try {
val pipelineFork = Await.result(pipelineForkFuture, pipelineCancelationTimeout)
oxDispatcher.runAsync {
race(
{
ox.sleep(pipelineCancelationTimeout)
logger.error(s"Pipeline fork cancelation did not complete in time ($pipelineCancelationTimeout).")
},
pipelineFork.cancel()
) match {
case Left(NonFatal(e)) => logger.error("Error when canceling pipeline fork", e)
case _ => ()
}
} { e => logger.error("Error when canceling pipeline fork", e) }.discard
} catch case NonFatal(e) => logger.error("Error when waiting for pipeline fork to start", e)

private def cancelSubscription() =
if requestsSubscription != null then
try requestsSubscription.cancel()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
package sttp.tapir.server.netty.sync

import cats.data.NonEmptyList
import cats.effect.unsafe.implicits.global
import cats.effect.IO
import cats.effect.unsafe.implicits.global
import io.netty.channel.nio.NioEventLoopGroup
import org.scalactic.source.Position
import org.scalatest.BeforeAndAfterAll
import org.scalatest.compatible.Assertion
import org.scalatest.funsuite.AsyncFunSuite
import org.scalatest.BeforeAndAfterAll
import org.scalatest.matchers.should.Matchers.*
import org.slf4j.LoggerFactory
import ox.*
import ox.channels.Source
import ox.channels.*
import sttp.capabilities.WebSockets
import sttp.capabilities.fs2.Fs2Streams
import sttp.client3.*
import sttp.model.*
import sttp.tapir.PublicEndpoint
import sttp.tapir.*
import sttp.tapir.server.ServerEndpoint
import sttp.tapir.server.tests.*
import sttp.tapir.tests.*
import sttp.ws.{WebSocket, WebSocketFrame}

import scala.concurrent.duration.FiniteDuration
import java.util.concurrent.{CompletableFuture, TimeUnit}
import scala.concurrent.Future
import scala.concurrent.duration.FiniteDuration

class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {

Expand All @@ -41,6 +44,39 @@ class NettySyncServerTest extends AsyncFunSuite with BeforeAndAfterAll {
new ServerWebSocketTests(createServerTest, OxStreams, autoPing = true, failingPipe = true, handlePong = true) {
override def functionToPipe[A, B](f: A => B): OxStreams.Pipe[A, B] = ox ?=> in => in.map(f)
override def emptyPipe[A, B]: OxStreams.Pipe[A, B] = _ => Source.empty

import createServerTest._
override def tests(): List[Test] = super.tests() ++ List({
val released: CompletableFuture[Boolean] = new CompletableFuture[Boolean]()
testServer(
endpoint.out(webSocketBody[String, CodecFormat.TextPlain, String, CodecFormat.TextPlain].apply(streams)),
"closes supervision scope when client closes Web Socket without getting any responses"
)((_: Unit) =>
val pipe: OxStreams.Pipe[String, String] = in => {
val outgoing = Channel.bufferedDefault[String]
releaseAfterScope {
released.complete(true).discard
}
outgoing
}
Right(pipe)
) { (backend, baseUri) =>
basicRequest
.response(asWebSocket { (ws: WebSocket[IO]) =>
for {
_ <- ws.sendText("test1")
_ <- ws.close()
closeResponse <- ws.eitherClose(ws.receiveText())
} yield closeResponse
})
.get(baseUri.scheme("ws"))
.send(backend)
.map { r =>
r.body.value shouldBe Left(WebSocketFrame.Close(1000, "normal closure"))
released.get(15, TimeUnit.SECONDS) shouldBe true
}
}
})
}.tests()

tests.foreach { t =>
Expand Down

0 comments on commit 585434d

Please sign in to comment.