diff --git a/core/src/main/scala/in/rcard/sus4s/sus4s.scala b/core/src/main/scala/in/rcard/sus4s/sus4s.scala index a80a67e..cc2f6c4 100644 --- a/core/src/main/scala/in/rcard/sus4s/sus4s.scala +++ b/core/src/main/scala/in/rcard/sus4s/sus4s.scala @@ -1,6 +1,8 @@ package in.rcard.sus4s -import java.util.concurrent.{CompletableFuture, StructuredTaskScope} +import java.util.concurrent.StructuredTaskScope.Subtask +import java.util.concurrent.{CompletableFuture, ExecutionException, StructuredTaskScope} +import scala.compiletime.uninitialized object sus4s { @@ -21,7 +23,18 @@ object sus4s { * The type of the value returned by the job */ class Job[A] private[sus4s] (private val cf: CompletableFuture[A]) { - def value: A = cf.join() + def value: A = cf.get() + } + + class CancellableJob[A] private[sus4s] (val cf: CompletableFuture[A]) extends Job[A](cf) { + var scope: StructuredTaskScope[Any] = uninitialized + def cancel(): Unit = { + cf.completeExceptionally(new InterruptedException("Job cancelled")) + try + cf.get() + catch + case e => + } } /** Executes a block of code applying structured concurrency to the contained suspendable tasks @@ -84,8 +97,8 @@ object sus4s { /** Forks a new concurrent task executing the given block of code and returning a [[Job]] that * completes with the value of type `A`. The task is executed in a new Virtual Thread using the - * given [[Suspend]] context. The job is transparent to any exception thrown by the `block`, which - * means it rethrows the exception. + * given [[Suspend]] context. The job is transparent to any exception thrown by the `block`, + * which means it rethrows the exception. * *

Example

* {{{ @@ -119,4 +132,50 @@ object sus4s { }) Job(result) } + + def forkCancellable[A](block: Suspend ?=> A): Suspend ?=> CancellableJob[A] = { + val result = new CompletableFuture[A]() + val innerResult = new CompletableFuture[A]() + val cancellableJob = CancellableJob(innerResult) + summon[Suspend].scope.fork(() => { + val innerScope = new StructuredTaskScope.ShutdownOnFailure() + cancellableJob.scope = innerScope + + given innerSuspended: Suspend = new Suspend { + override val scope: StructuredTaskScope[Any] = innerScope + } + try { + val subtask = innerScope.fork(() => { + val result = block(using innerSuspended) + innerResult.complete(result) + result + }) + try + innerResult.get() + catch + case e: Throwable => + innerScope.shutdown() + innerScope.join().throwIfFailed(identity) + if (subtask.state() == Subtask.State.UNAVAILABLE) { + // TODO Handle all cases + result.completeExceptionally(new InterruptedException("Job cancelled")) + } else { + result.complete(subtask.get()) + } + } catch + case exex: ExecutionException => + exex.getCause match + case ie: InterruptedException => innerScope.shutdown() + case e: Throwable => + result.completeExceptionally(e) + throw e + case e: Throwable => + result.completeExceptionally(e) + throw e + finally { + innerScope.close() + } + }) + cancellableJob + } } diff --git a/core/src/test/scala/StructuredSpec.scala b/core/src/test/scala/StructuredSpec.scala index 6fc15c5..934d638 100644 --- a/core/src/test/scala/StructuredSpec.scala +++ b/core/src/test/scala/StructuredSpec.scala @@ -1,5 +1,5 @@ import in.rcard.sus4s.sus4s -import in.rcard.sus4s.sus4s.{fork, structured} +import in.rcard.sus4s.sus4s.{fork, forkCancellable, structured} import org.scalatest.TryValues.* import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -99,4 +99,26 @@ class StructuredSpec extends AnyFlatSpec with Matchers { queue.toArray should contain theSameElementsInOrderAs List("job2", "job1") result shouldBe 85 } + + it should "cancel at the first suspending point" in { + val queue = new ConcurrentLinkedQueue[String]() + val result = structured { + val cancellable = forkCancellable { + while (true) { + Thread.sleep(2000) + println("cancellable job") + queue.add("cancellable") + } + } + val job = fork { + cancellable.cancel() + queue.add("job2") + 43 + } + job.value + } + + queue.toArray should contain theSameElementsInOrderAs List("job2") + result shouldBe 43 + } }