diff --git a/README.md b/README.md
index 77d72c0..0f02c58 100644
--- a/README.md
+++ b/README.md
@@ -172,6 +172,33 @@ Trying to get the value from a canceled job will throw an `InterruptedException`
**You won't pay any additional cost for canceling a job**. The cancellation mechanism is based on the interruption of the virtual thread. No new structured scope is created for the cancellation mechanism.
+## Racing Jobs
+
+The library provides the `race` method to race two jobs. The `race` function returns the result of the first job that completes. The other job is canceled. The following code snippet shows how to use the `race` method:
+
+```scala 3
+val results = new ConcurrentLinkedQueue[String]()
+val actual: Int | String = structured {
+ race[Int, String](
+ {
+ delay(1.second)
+ results.add("job1")
+ throw new RuntimeException("Error")
+ }, {
+ delay(500.millis)
+ results.add("job2")
+ "42"
+ }
+ )
+}
+actual should be("42")
+results.toArray should contain theSameElementsInOrderAs List("job2")
+```
+
+If the first job completes with an exception, the `race` method waits for the second job to complete. and returns the result of the second job. If the second job completes with an exception, the `race` method throws the first exception it encountered.
+
+Each job adhere to the rules of structured concurrency. The `race` function is optimized. Every raced block creates more than one virtual thread under the hood, which should not be a problem for the Loom runtime.
+
## Contributing
If you want to contribute to the project, please do it! Any help is welcome.
diff --git a/core/src/main/scala/in/rcard/sus4s/sus4s.scala b/core/src/main/scala/in/rcard/sus4s/sus4s.scala
index fad807e..563ab63 100644
--- a/core/src/main/scala/in/rcard/sus4s/sus4s.scala
+++ b/core/src/main/scala/in/rcard/sus4s/sus4s.scala
@@ -1,6 +1,6 @@
package in.rcard.sus4s
-import java.util.concurrent.StructuredTaskScope.ShutdownOnFailure
+import java.util.concurrent.StructuredTaskScope.{ShutdownOnFailure, ShutdownOnSuccess}
import java.util.concurrent.{CompletableFuture, StructuredTaskScope}
import scala.concurrent.ExecutionException
import scala.concurrent.duration.Duration
@@ -58,34 +58,34 @@ object sus4s {
/** Cancels the job and all its children jobs. Getting the value of a cancelled job throws an
* [[InterruptedException]]. Cancellation is an idempotent operation.
- *
- *
Example
- * {{{
- * val expectedQueue = structured {
- * val queue = new ConcurrentLinkedQueue[String]()
- * val job1 = fork {
- * val innerJob = fork {
- * fork {
- * Thread.sleep(3000)
- * println("inner-inner-Job")
- * queue.add("inner-inner-Job")
- * }
- * Thread.sleep(2000)
- * println("innerJob")
- * queue.add("innerJob")
- * }
- * Thread.sleep(1000)
- * queue.add("job1")
- * }
- * val job = fork {
- * Thread.sleep(500)
- * job1.cancel()
- * queue.add("job2")
- * }
- * queue
- * }
- * expectedQueue.toArray should contain theSameElementsInOrderAs List("job2")
- * }}}
+ *
+ * Example
+ * {{{
+ * val expectedQueue = structured {
+ * val queue = new ConcurrentLinkedQueue[String]()
+ * val job1 = fork {
+ * val innerJob = fork {
+ * fork {
+ * Thread.sleep(3000)
+ * println("inner-inner-Job")
+ * queue.add("inner-inner-Job")
+ * }
+ * Thread.sleep(2000)
+ * println("innerJob")
+ * queue.add("innerJob")
+ * }
+ * Thread.sleep(1000)
+ * queue.add("job1")
+ * }
+ * val job = fork {
+ * Thread.sleep(500)
+ * job1.cancel()
+ * queue.add("job2")
+ * }
+ * queue
+ * }
+ * expectedQueue.toArray should contain theSameElementsInOrderAs List("job2")
+ * }}}
*/
def cancel(): Suspend ?=> Unit = {
// FIXME Refactor this code
@@ -199,8 +199,11 @@ object sus4s {
case Some(l) => Some(childThread :: l)
}
executingThread.complete(childThread)
- try result.complete(block)
- catch
+ try {
+ val resultValue = block
+ result.complete(resultValue)
+ resultValue
+ } catch
case _: InterruptedException =>
result.completeExceptionally(new InterruptedException("Job cancelled"))
case throwable: Throwable =>
@@ -209,7 +212,7 @@ object sus4s {
})
Job(result, executingThread)
}
-
+
/** Suspends the execution of the current thread for the given duration.
*
* @param duration
@@ -218,4 +221,56 @@ object sus4s {
def delay(duration: Duration): Suspend ?=> Unit = {
Thread.sleep(duration.toMillis)
}
+
+ /** Races two concurrent tasks and returns the result of the first one that completes. The other
+ * task is cancelled. If the first task throws an exception, it waits for the end of the second
+ * task. If both tasks throw an exception, the first one is rethrown.
+ *
+ * Each block follows the [[structured]] concurrency model. So, for each block, a new virtual
+ * thread is created more than the thread forking the block.
+ *
+ * Example
+ * {{{
+ * val results = new ConcurrentLinkedQueue[String]()
+ * val actual: Int | String = structured {
+ * race[Int, String](
+ * {
+ * delay(1.second)
+ * results.add("job1")
+ * throw new RuntimeException("Error")
+ * }, {
+ * delay(500.millis)
+ * results.add("job2")
+ * "42"
+ * }
+ * )
+ * }
+ * actual should be("42")
+ * results.toArray should contain theSameElementsInOrderAs List("job2")
+ * }}}
+ *
+ * @param firstBlock
+ * First block to race
+ * @param secondBlock
+ * Second block to race
+ * @tparam A
+ * Result type of the first block
+ * @tparam B
+ * Result type of the second block
+ * @return
+ * The result of the first block that completes
+ */
+ def race[A, B](firstBlock: Suspend ?=> A, secondBlock: Suspend ?=> B): Suspend ?=> A | B = {
+ val loomScope = new ShutdownOnSuccess[A | B]()
+ given suspended: Suspend = SuspendScope(loomScope.asInstanceOf[StructuredTaskScope[Any]])
+ try {
+ loomScope.fork(() => { structured { firstBlock } })
+ loomScope.fork(() => { structured { secondBlock } })
+
+ loomScope.join()
+ loomScope.result(identity)
+ } finally {
+ loomScope.close()
+ }
+ }
}
diff --git a/core/src/test/scala/CancelSpec.scala b/core/src/test/scala/CancelSpec.scala
new file mode 100644
index 0000000..82d9334
--- /dev/null
+++ b/core/src/test/scala/CancelSpec.scala
@@ -0,0 +1,136 @@
+import in.rcard.sus4s.sus4s
+import in.rcard.sus4s.sus4s.{delay, fork, structured}
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+
+import java.util.concurrent.ConcurrentLinkedQueue
+import scala.concurrent.duration.*
+
+class CancelSpec extends AnyFlatSpec with Matchers {
+ "cancellation" should "cancel at the first suspending point" in {
+ val expectedQueue = structured {
+ val queue = new ConcurrentLinkedQueue[String]()
+ val cancellable = fork {
+ delay(2.seconds)
+ queue.add("cancellable")
+ }
+ val job = fork {
+ delay(500.millis)
+ cancellable.cancel()
+ queue.add("job2")
+ }
+ queue
+ }
+ expectedQueue.toArray should contain theSameElementsInOrderAs List("job2")
+ }
+
+ it should "not throw an exception if joined" in {
+
+ val expectedQueue = structured {
+ val queue = new ConcurrentLinkedQueue[String]()
+ val cancellable = fork {
+ delay(2.seconds)
+ queue.add("cancellable")
+ }
+ val job = fork {
+ delay(500.millis)
+ cancellable.cancel()
+ queue.add("job2")
+ }
+ cancellable.join()
+ queue
+ }
+ expectedQueue.toArray should contain theSameElementsInOrderAs List("job2")
+ }
+
+ it should "not cancel parent job" in {
+
+ val expectedQueue = structured {
+ val queue = new ConcurrentLinkedQueue[String]()
+ val job1 = fork {
+ val innerCancellableJob = fork {
+ delay(2.seconds)
+ queue.add("cancellable")
+ }
+ delay(1.second)
+ innerCancellableJob.cancel()
+ queue.add("job1")
+ }
+ val job = fork {
+ delay(500.millis)
+ queue.add("job2")
+ }
+ queue
+ }
+ expectedQueue.toArray should contain theSameElementsInOrderAs List("job2", "job1")
+ }
+
+ it should "cancel children jobs" in {
+ val expectedQueue = structured {
+ val queue = new ConcurrentLinkedQueue[String]()
+ val job1 = fork {
+ val innerJob = fork {
+ fork {
+ delay(3.seconds)
+ println("inner-inner-Job")
+ queue.add("inner-inner-Job")
+ }
+
+ delay(2.seconds)
+ println("innerJob")
+ queue.add("innerJob")
+ }
+ delay(1.second)
+ queue.add("job1")
+ }
+ val job = fork {
+ delay(500.millis)
+ job1.cancel()
+ queue.add("job2")
+ }
+ queue
+ }
+ expectedQueue.toArray should contain theSameElementsInOrderAs List("job2")
+ }
+
+ it should "not throw any exception when joining a cancelled job" in {
+ val expected = structured {
+ val cancellable = fork {
+ delay(2.seconds)
+ }
+ delay(500.millis)
+ cancellable.cancel()
+ cancellable.join()
+ 42
+ }
+
+ expected shouldBe 42
+ }
+
+ it should "not throw any exception if a job is canceled twice" in {
+ val expected = structured {
+ val cancellable = fork {
+ delay(2.seconds)
+ }
+ delay(500.millis)
+ cancellable.cancel()
+ cancellable.cancel()
+ 42
+ }
+
+ expected shouldBe 42
+ }
+
+ it should "throw an exception when asking for the value of a cancelled job" in {
+ assertThrows[InterruptedException] {
+ structured {
+ val cancellable = fork {
+ delay(2.seconds)
+ }
+ delay(500.millis)
+ cancellable.cancel()
+ cancellable.value
+ }
+ }
+ }
+}
diff --git a/core/src/test/scala/RaceSpec.scala b/core/src/test/scala/RaceSpec.scala
new file mode 100644
index 0000000..080723e
--- /dev/null
+++ b/core/src/test/scala/RaceSpec.scala
@@ -0,0 +1,115 @@
+import in.rcard.sus4s.sus4s.{delay, fork, race, structured}
+import org.scalatest.TryValues.*
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+
+import java.util.concurrent.ConcurrentLinkedQueue
+import scala.concurrent.duration.*
+import scala.util.Try
+
+class RaceSpec extends AnyFlatSpec with Matchers {
+ "Racing two functions" should "return the result of the first one that completes and cancel the execution of the other" in {
+ val results = new ConcurrentLinkedQueue[String]()
+ val actual: Int | String = structured {
+ race[Int, String](
+ {
+ delay(1.second)
+ results.add("job1")
+ throw new RuntimeException("Error")
+ }, {
+ delay(500.millis)
+ results.add("job2")
+ "42"
+ }
+ )
+ }
+
+ actual should be("42")
+ results.toArray should contain theSameElementsInOrderAs List("job2")
+ }
+
+ it should "return the result of the second one if the first one throws an exception" in {
+ val results = new ConcurrentLinkedQueue[String]()
+ val actual: Int | String = structured {
+ race(
+ {
+ delay(1.second)
+ results.add("job1")
+ 42
+ }, {
+ delay(500.millis)
+ results.add("job2")
+ throw new RuntimeException("Error")
+ }
+ )
+ }
+
+ actual should be(42)
+ results.toArray should contain theSameElementsInOrderAs List("job2", "job1")
+ }
+
+ it should "honor the structural concurrency and wait for all the jobs to complete" in {
+ val results = new ConcurrentLinkedQueue[String]()
+ val actual: Int | String = structured {
+ race(
+ {
+ val job1 = fork {
+ fork {
+ delay(2.second)
+ results.add("job3")
+ }
+ delay(1.second)
+ results.add("job1")
+ }
+ 42
+ }, {
+ delay(500.millis)
+ throw new RuntimeException("Error")
+ }
+ )
+ }
+
+ actual should be(42)
+ results.toArray should contain theSameElementsInOrderAs List("job1", "job3")
+ }
+
+ it should "throw the exception thrown by the first function both throw an exception" in {
+ val expectedResult = Try {
+ structured {
+ race(
+ {
+ delay(1.second)
+ throw new RuntimeException("Error in job1")
+ }, {
+ delay(500.millis)
+ throw new RuntimeException("Error in job2")
+ }
+ )
+ }
+ }
+
+ expectedResult.failure.exception shouldBe a[RuntimeException]
+ expectedResult.failure.exception.getMessage shouldBe "Error in job2"
+ }
+
+ it should "honor the structural concurrency and return the value of the second function if the first threw an exception" in {
+ val actual: Int | String = structured {
+ race(
+ {
+ val job1 = fork {
+ delay(500.millis)
+ println("job1")
+ throw new RuntimeException("Error in job1")
+ }
+ 42
+ }, {
+ delay(1.second)
+ println("job2")
+ "42"
+ }
+ )
+ }
+
+ actual should be("42")
+ }
+}
diff --git a/core/src/test/scala/StructuredSpec.scala b/core/src/test/scala/StructuredSpec.scala
index 380abbf..ca012b4 100644
--- a/core/src/test/scala/StructuredSpec.scala
+++ b/core/src/test/scala/StructuredSpec.scala
@@ -56,6 +56,34 @@ class StructuredSpec extends AnyFlatSpec with Matchers {
results.toArray should contain theSameElementsInOrderAs List("job3", "job2")
}
+ it should "stop the execution if a child job throws an exception" in {
+ val results = new ConcurrentLinkedQueue[String]()
+ val tryResult = Try {
+ structured {
+ val job1 = fork {
+ delay(1.second)
+ results.add("job1")
+ }
+ val job2 = fork {
+ delay(500.millis)
+ results.add("job2")
+ fork {
+ delay(100.millis)
+ throw new RuntimeException("Error")
+ }
+ }
+ val job3 = fork {
+ delay(100.millis)
+ results.add("job3")
+ }
+ }
+ }
+
+ tryResult.failure.exception shouldBe a[RuntimeException]
+ tryResult.failure.exception.getMessage shouldBe "Error"
+ results.toArray should contain theSameElementsInOrderAs List("job3", "job2")
+ }
+
it should "stop the execution if the block throws an exception" in {
val results = new ConcurrentLinkedQueue[String]()
val tryResult = Try {
@@ -120,131 +148,4 @@ class StructuredSpec extends AnyFlatSpec with Matchers {
results.toArray should contain theSameElementsInOrderAs List("3", "2", "1")
}
-
- "cancellation" should "cancel at the first suspending point" in {
- val expectedQueue = structured {
- val queue = new ConcurrentLinkedQueue[String]()
- val cancellable = fork {
- delay(2.seconds)
- queue.add("cancellable")
- }
- val job = fork {
- delay(500.millis)
- cancellable.cancel()
- queue.add("job2")
- }
- queue
- }
- expectedQueue.toArray should contain theSameElementsInOrderAs List("job2")
- }
-
- it should "not throw an exception if joined" in {
-
- val expectedQueue = structured {
- val queue = new ConcurrentLinkedQueue[String]()
- val cancellable = fork {
- delay(2.seconds)
- queue.add("cancellable")
- }
- val job = fork {
- delay(500.millis)
- cancellable.cancel()
- queue.add("job2")
- }
- cancellable.join()
- queue
- }
- expectedQueue.toArray should contain theSameElementsInOrderAs List("job2")
- }
-
- it should "not cancel parent job" in {
-
- val expectedQueue = structured {
- val queue = new ConcurrentLinkedQueue[String]()
- val job1 = fork {
- val innerCancellableJob = fork {
- delay(2.seconds)
- queue.add("cancellable")
- }
- delay(1.second)
- innerCancellableJob.cancel()
- queue.add("job1")
- }
- val job = fork {
- delay(500.millis)
- queue.add("job2")
- }
- queue
- }
- expectedQueue.toArray should contain theSameElementsInOrderAs List("job2", "job1")
- }
-
- it should "cancel children jobs" in {
- val expectedQueue = structured {
- val queue = new ConcurrentLinkedQueue[String]()
- val job1 = fork {
- val innerJob = fork {
- fork {
- delay(3.seconds)
- println("inner-inner-Job")
- queue.add("inner-inner-Job")
- }
-
- delay(2.seconds)
- println("innerJob")
- queue.add("innerJob")
- }
- delay(1.second)
- queue.add("job1")
- }
- val job = fork {
- delay(500.millis)
- job1.cancel()
- queue.add("job2")
- }
- queue
- }
- expectedQueue.toArray should contain theSameElementsInOrderAs List("job2")
- }
-
- it should "not throw any exception when joining a cancelled job" in {
- val expected = structured {
- val cancellable = fork {
- delay(2.seconds)
- }
- delay(500.millis)
- cancellable.cancel()
- cancellable.join()
- 42
- }
-
- expected shouldBe 42
- }
-
- it should "not throw any exception if a job is canceled twice" in {
- val expected = structured {
- val cancellable = fork {
- delay(2.seconds)
- }
- delay(500.millis)
- cancellable.cancel()
- cancellable.cancel()
- 42
- }
-
- expected shouldBe 42
- }
-
- it should "throw an exception when asking for the value of a cancelled job" in {
- assertThrows[InterruptedException] {
- structured {
- val cancellable = fork {
- delay(2.seconds)
- }
- delay(500.millis)
- cancellable.cancel()
- cancellable.value
- }
- }
- }
}