Skip to content

Commit

Permalink
First (very) rough version of cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
rcardin committed May 22, 2024
1 parent ba91067 commit f2a9631
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 5 deletions.
67 changes: 63 additions & 4 deletions core/src/main/scala/in/rcard/sus4s/sus4s.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package in.rcard.sus4s

import java.util.concurrent.{CompletableFuture, StructuredTaskScope}
import java.util.concurrent.StructuredTaskScope.Subtask
import java.util.concurrent.{CompletableFuture, ExecutionException, StructuredTaskScope}
import scala.compiletime.uninitialized

object sus4s {

Expand All @@ -21,7 +23,18 @@ object sus4s {
* The type of the value returned by the job
*/
class Job[A] private[sus4s] (private val cf: CompletableFuture[A]) {
def value: A = cf.join()
def value: A = cf.get()
}

class CancellableJob[A] private[sus4s] (val cf: CompletableFuture[A]) extends Job[A](cf) {
var scope: StructuredTaskScope[Any] = uninitialized
def cancel(): Unit = {
cf.completeExceptionally(new InterruptedException("Job cancelled"))
try
cf.get()
catch
case e =>
}
}

/** Executes a block of code applying structured concurrency to the contained suspendable tasks
Expand Down Expand Up @@ -84,8 +97,8 @@ object sus4s {

/** Forks a new concurrent task executing the given block of code and returning a [[Job]] that
* completes with the value of type `A`. The task is executed in a new Virtual Thread using the
* given [[Suspend]] context. The job is transparent to any exception thrown by the `block`, which
* means it rethrows the exception.
* given [[Suspend]] context. The job is transparent to any exception thrown by the `block`,
* which means it rethrows the exception.
*
* <h2>Example</h2>
* {{{
Expand Down Expand Up @@ -119,4 +132,50 @@ object sus4s {
})
Job(result)
}

def forkCancellable[A](block: Suspend ?=> A): Suspend ?=> CancellableJob[A] = {
val result = new CompletableFuture[A]()
val innerResult = new CompletableFuture[A]()
val cancellableJob = CancellableJob(innerResult)
summon[Suspend].scope.fork(() => {
val innerScope = new StructuredTaskScope.ShutdownOnFailure()
cancellableJob.scope = innerScope

given innerSuspended: Suspend = new Suspend {
override val scope: StructuredTaskScope[Any] = innerScope
}
try {
val subtask = innerScope.fork(() => {
val result = block(using innerSuspended)
innerResult.complete(result)
result
})
try
innerResult.get()
catch
case e: Throwable =>
innerScope.shutdown()
innerScope.join().throwIfFailed(identity)
if (subtask.state() == Subtask.State.UNAVAILABLE) {
// TODO Handle all cases
result.completeExceptionally(new InterruptedException("Job cancelled"))
} else {
result.complete(subtask.get())
}
} catch
case exex: ExecutionException =>
exex.getCause match
case ie: InterruptedException => innerScope.shutdown()
case e: Throwable =>
result.completeExceptionally(e)
throw e
case e: Throwable =>
result.completeExceptionally(e)
throw e
finally {
innerScope.close()
}
})
cancellableJob
}
}
24 changes: 23 additions & 1 deletion core/src/test/scala/StructuredSpec.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import in.rcard.sus4s.sus4s
import in.rcard.sus4s.sus4s.{fork, structured}
import in.rcard.sus4s.sus4s.{fork, forkCancellable, structured}
import org.scalatest.TryValues.*
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
Expand Down Expand Up @@ -99,4 +99,26 @@ class StructuredSpec extends AnyFlatSpec with Matchers {
queue.toArray should contain theSameElementsInOrderAs List("job2", "job1")
result shouldBe 85
}

it should "cancel at the first suspending point" in {
val queue = new ConcurrentLinkedQueue[String]()
val result = structured {
val cancellable = forkCancellable {
while (true) {
Thread.sleep(2000)
println("cancellable job")
queue.add("cancellable")
}
}
val job = fork {
cancellable.cancel()
queue.add("job2")
43
}
job.value
}

queue.toArray should contain theSameElementsInOrderAs List("job2")
result shouldBe 43
}
}

0 comments on commit f2a9631

Please sign in to comment.