diff --git a/build.sbt b/build.sbt index 204dd831..60bacf45 100644 --- a/build.sbt +++ b/build.sbt @@ -2,6 +2,8 @@ import com.softwaremill.SbtSoftwareMillCommon.commonSmlBuildSettings import com.softwaremill.Publish.{ossPublishSettings, updateDocs} import com.softwaremill.UpdateVersionInDocs +Global / cancelable := true + lazy val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq( organization := "com.softwaremill.ox", scalaVersion := "3.3.3", @@ -50,7 +52,8 @@ lazy val core: Project = (project in file("core")) scalaTest ), // Check IO usage in core - useRequireIOPlugin + useRequireIOPlugin, + Test / fork := true ) lazy val plugin: Project = (project in file("plugin")) diff --git a/core/src/main/scala/ox/fork.scala b/core/src/main/scala/ox/fork.scala index fe4d8a1f..f6ebd8bf 100644 --- a/core/src/main/scala/ox/fork.scala +++ b/core/src/main/scala/ox/fork.scala @@ -122,6 +122,7 @@ def forkAll[T](fs: Seq[() => T])(using Ox): Fork[Seq[T]] = val forks = fs.map(f => fork(f())) new Fork[Seq[T]]: override def join(): Seq[T] = forks.map(_.join()) + override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = forks.exists(_.wasInterruptedWith(ie)) /** Starts a fork (logical thread of execution), which is guaranteed to complete before the enclosing [[supervised]], [[supervisedError]] or * [[unsupervised]] block completes, and which can be cancelled on-demand. @@ -177,8 +178,13 @@ def forkCancellable[T](f: => T)(using OxUnsupervised): CancellableFork[T] = if !started.getAndSet(true) then result.completeExceptionally(new InterruptedException("fork was cancelled before it started")).discard + override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = + result.isCompletedExceptionally && (result.exceptionNow() eq ie) + private def newForkUsingResult[T](result: CompletableFuture[T]): Fork[T] = new Fork[T]: override def join(): T = unwrapExecutionException(result.get()) + override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = + result.isCompletedExceptionally && (result.exceptionNow() eq ie) private[ox] inline def unwrapExecutionException[T](f: => T): T = try f @@ -208,16 +214,23 @@ trait Fork[T]: def joinEither(): Either[Throwable, T] = try Right(join()) catch - // normally IE is fatal, but here it was meant to cancel the fork, not the joining parent, hence we catch it - case e: InterruptedException => Left(e) + // normally IE is fatal, but here it could have meant that the fork was cancelled, hence we catch it + // we do discern between the fork and the current thread being cancelled and rethrow if it's us who's getting the axe + case e: InterruptedException => if wasInterruptedWith(e) then Left(e) else throw e case NonFatal(e) => Left(e) + private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean + object Fork: /** A dummy pretending to represent a fork which successfully completed with the given value. */ - def successful[T](value: T): Fork[T] = () => value + def successful[T](value: T): Fork[T] = new Fork[T]: + override def join(): T = value + override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = false /** A dummy pretending to represent a fork which failed with the given exception. */ - def failed[T](e: Throwable): Fork[T] = () => throw e + def failed[T](e: Throwable): Fork[T] = new Fork[T]: + override def join(): T = throw e + override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = e eq ie /** A fork started using [[forkCancellable]], backed by a (virtual) thread. */ trait CancellableFork[T] extends Fork[T]: diff --git a/core/src/test/scala/ox/SupervisedTest.scala b/core/src/test/scala/ox/SupervisedTest.scala index 4af73ab5..bb86c23e 100644 --- a/core/src/test/scala/ox/SupervisedTest.scala +++ b/core/src/test/scala/ox/SupervisedTest.scala @@ -123,4 +123,31 @@ class SupervisedTest extends AnyFlatSpec with Matchers { trail.add("done") trail.get shouldBe Vector("b", "a", "done") } + + it should "handle interruption of multiple forks with `joinEither` correctly" in { + val e = intercept[Exception] { + supervised { + def computation(withException: Option[String]): Int = { + withException match + case None => 1 + case Some(value) => + throw new Exception(value) + } + + val fork1 = fork: + computation(withException = None) + val fork2 = fork: + computation(withException = Some("Oh no!")) + val fork3 = fork: + computation(withException = Some("Oh well..")) + + fork1.joinEither() // 1 + fork2.joinEither() // 2 + fork3.joinEither() // 3 + } + } + + e.getMessage should startWith("Oh") + } + }