diff --git a/core/src/main/scala/in/rcard/sus4s/sus4s.scala b/core/src/main/scala/in/rcard/sus4s/sus4s.scala
index a80a67e..cc2f6c4 100644
--- a/core/src/main/scala/in/rcard/sus4s/sus4s.scala
+++ b/core/src/main/scala/in/rcard/sus4s/sus4s.scala
@@ -1,6 +1,8 @@
package in.rcard.sus4s
-import java.util.concurrent.{CompletableFuture, StructuredTaskScope}
+import java.util.concurrent.StructuredTaskScope.Subtask
+import java.util.concurrent.{CompletableFuture, ExecutionException, StructuredTaskScope}
+import scala.compiletime.uninitialized
object sus4s {
@@ -21,7 +23,18 @@ object sus4s {
* The type of the value returned by the job
*/
class Job[A] private[sus4s] (private val cf: CompletableFuture[A]) {
- def value: A = cf.join()
+ def value: A = cf.get()
+ }
+
+ class CancellableJob[A] private[sus4s] (val cf: CompletableFuture[A]) extends Job[A](cf) {
+ var scope: StructuredTaskScope[Any] = uninitialized
+ def cancel(): Unit = {
+ cf.completeExceptionally(new InterruptedException("Job cancelled"))
+ try
+ cf.get()
+ catch
+ case e =>
+ }
}
/** Executes a block of code applying structured concurrency to the contained suspendable tasks
@@ -84,8 +97,8 @@ object sus4s {
/** Forks a new concurrent task executing the given block of code and returning a [[Job]] that
* completes with the value of type `A`. The task is executed in a new Virtual Thread using the
- * given [[Suspend]] context. The job is transparent to any exception thrown by the `block`, which
- * means it rethrows the exception.
+ * given [[Suspend]] context. The job is transparent to any exception thrown by the `block`,
+ * which means it rethrows the exception.
*
*
Example
* {{{
@@ -119,4 +132,50 @@ object sus4s {
})
Job(result)
}
+
+ def forkCancellable[A](block: Suspend ?=> A): Suspend ?=> CancellableJob[A] = {
+ val result = new CompletableFuture[A]()
+ val innerResult = new CompletableFuture[A]()
+ val cancellableJob = CancellableJob(innerResult)
+ summon[Suspend].scope.fork(() => {
+ val innerScope = new StructuredTaskScope.ShutdownOnFailure()
+ cancellableJob.scope = innerScope
+
+ given innerSuspended: Suspend = new Suspend {
+ override val scope: StructuredTaskScope[Any] = innerScope
+ }
+ try {
+ val subtask = innerScope.fork(() => {
+ val result = block(using innerSuspended)
+ innerResult.complete(result)
+ result
+ })
+ try
+ innerResult.get()
+ catch
+ case e: Throwable =>
+ innerScope.shutdown()
+ innerScope.join().throwIfFailed(identity)
+ if (subtask.state() == Subtask.State.UNAVAILABLE) {
+ // TODO Handle all cases
+ result.completeExceptionally(new InterruptedException("Job cancelled"))
+ } else {
+ result.complete(subtask.get())
+ }
+ } catch
+ case exex: ExecutionException =>
+ exex.getCause match
+ case ie: InterruptedException => innerScope.shutdown()
+ case e: Throwable =>
+ result.completeExceptionally(e)
+ throw e
+ case e: Throwable =>
+ result.completeExceptionally(e)
+ throw e
+ finally {
+ innerScope.close()
+ }
+ })
+ cancellableJob
+ }
}
diff --git a/core/src/test/scala/StructuredSpec.scala b/core/src/test/scala/StructuredSpec.scala
index 6fc15c5..934d638 100644
--- a/core/src/test/scala/StructuredSpec.scala
+++ b/core/src/test/scala/StructuredSpec.scala
@@ -1,5 +1,5 @@
import in.rcard.sus4s.sus4s
-import in.rcard.sus4s.sus4s.{fork, structured}
+import in.rcard.sus4s.sus4s.{fork, forkCancellable, structured}
import org.scalatest.TryValues.*
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
@@ -99,4 +99,26 @@ class StructuredSpec extends AnyFlatSpec with Matchers {
queue.toArray should contain theSameElementsInOrderAs List("job2", "job1")
result shouldBe 85
}
+
+ it should "cancel at the first suspending point" in {
+ val queue = new ConcurrentLinkedQueue[String]()
+ val result = structured {
+ val cancellable = forkCancellable {
+ while (true) {
+ Thread.sleep(2000)
+ println("cancellable job")
+ queue.add("cancellable")
+ }
+ }
+ val job = fork {
+ cancellable.cancel()
+ queue.add("job2")
+ 43
+ }
+ job.value
+ }
+
+ queue.toArray should contain theSameElementsInOrderAs List("job2")
+ result shouldBe 43
+ }
}